From 7456c8c6d25364a6f8f62085d2e83c88a57bb7be Mon Sep 17 00:00:00 2001 From: Wenhua Cheng Date: Thu, 2 Apr 2026 17:44:41 +0800 Subject: [PATCH 01/13] Update alg_ext.py --- auto_round/alg_ext.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/auto_round/alg_ext.py b/auto_round/alg_ext.py index 4d74f8bd9..52a11c28d 100644 --- a/auto_round/alg_ext.py +++ b/auto_round/alg_ext.py @@ -72,7 +72,7 @@ def get_abs_top_percent_mask(x: torch.Tensor, percent: float = 1.0): inv_mask (torch.BoolTensor): Inverse of mask. """ flat = x.view(-1) - k = max(1, int(flat.numel() * percent / 1000)) # 至少选1个 + k = max(1, int(flat.numel() * percent / 1000)) _, idx = torch.topk(torch.abs(flat), k) mask = torch.zeros_like(flat, dtype=torch.bool) @@ -612,7 +612,7 @@ def iterative_wls_quant_search(data, bits=4, rrmin=-1.0, rdelta=0.1, nstep=20, u # iscale_new = factor / (rmax - rmin + 1e-8) scale_new = (rmax - rmin) / factor iscale_new = get_reciprocal(scale_new) - quant_data_new = torch.clamp(torch.round(iscale_new * (data - rmin)), minq, maxq) + quant_data_new = torch.clamp(torch.round(iscale_new * (data - rmin) + v), minq, maxq) mul_weights_quant_data = weights * quant_data_new sum_l = torch.sum(mul_weights_quant_data, dim=-1, keepdim=True) From c52fa386b1d7a63fc76b38182c7b0c224a8ba1c5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 2 Apr 2026 09:45:20 +0000 Subject: [PATCH 02/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- auto_round/alg_ext.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/auto_round/alg_ext.py b/auto_round/alg_ext.py index 52a11c28d..dc9136cae 100644 --- a/auto_round/alg_ext.py +++ b/auto_round/alg_ext.py @@ -72,7 +72,7 @@ def get_abs_top_percent_mask(x: torch.Tensor, percent: float = 1.0): inv_mask (torch.BoolTensor): Inverse of mask. """ flat = x.view(-1) - k = max(1, int(flat.numel() * percent / 1000)) + k = max(1, int(flat.numel() * percent / 1000)) _, idx = torch.topk(torch.abs(flat), k) mask = torch.zeros_like(flat, dtype=torch.bool) From 86447297c2c46abee521535f0ecf41f9633ec6c1 Mon Sep 17 00:00:00 2001 From: Wenhua Cheng Date: Sat, 9 May 2026 15:45:57 +0800 Subject: [PATCH 03/13] refine alg_ext code to better support torch compile --- auto_round/alg_ext.py | 611 ++++++-------------- auto_round/data_type/gguf.py | 130 ++++- auto_round/data_type/int.py | 14 +- auto_round/export/export_to_gguf/packing.py | 19 +- 4 files changed, 285 insertions(+), 489 deletions(-) diff --git a/auto_round/alg_ext.py b/auto_round/alg_ext.py index 228247a85..c8c04e5e0 100644 --- a/auto_round/alg_ext.py +++ b/auto_round/alg_ext.py @@ -118,25 +118,14 @@ def quant_tensor_sym( init_scale=None, **kwargs, ): - """Quantize and de-quantize tensor asymmetrically. full range, credit goes to llamacpp community + """Quantize and de-quantize tensor symmetrically (full-range, llama.cpp style). - Args: - tensor: Tensor containing the tensor to be quantized - bits: Number of bits for quantization (e.g., 2, 3, 4, 8) - group_size: Number of elements to share scale for quantization - v: Rounding value perturbation - min_scale: Minimum scale coefficient for tensor - max_scale: Maximum scale coefficient for tensor - tensor_min (Tensor, optional): Minimum tensor value for quantization. Defaults to None. - tensor_max (Tensor, optional): Maximum tensor value for quantization. Defaults to None. - scale_dtype: dtype of the quantized scale,as most kernels only support FP16 or FP32 - q_scale_thresh: clip the quantized scale's magnitude to this value to improve the numerical stability - - Returns: - Quantized and de-quantized tensor, scale, zero-point + ``maxq`` is computed via ``int(2.0 ** (bits - 1))`` so it stays a plain + Python int constant inside the inductor graph and Triton never tries to + lower a ``2 ** SymInt`` through ``libdevice.pow(fp32, i64)``. """ tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size) - maxq = 2 ** (bits - 1) + maxq = int(2.0 ** (bits - 1)) scale = init_scale * max_scale.unsqueeze(dim=-1) int_w = round_ste(tensor / scale + v) q = torch.clamp(int_w, -maxq, maxq - 1) @@ -149,7 +138,7 @@ def quant_tensor_sym( def qdq_mxfp(tensor, max_val, max_norm, emax, ebits, mbits): shared_exp = torch.where(max_val == 0, torch.ones_like(max_val), torch.log2(max_val)) shared_exp = torch.floor(shared_exp) - scale_emax = 2 ** (8 - 1) - 1 + scale_emax = (1 << (8 - 1)) - 1 shared_exp = (shared_exp - emax).clamp(min=-scale_emax, max=scale_emax) scale = torch.pow(2.0, shared_exp) @@ -264,7 +253,7 @@ def quant_mx( # shared_exp = torch.log2(shared_exp + FP32_MIN_NORMAL * (shared_exp == 0).type(shared_exp.dtype)) shared_exp = torch.where(max_val == 0, torch.ones_like(max_val), torch.log2(max_val)) shared_exp = floor_ste(shared_exp) - scale_emax = 2 ** (8 - 1) - 1 + scale_emax = (1 << (8 - 1)) - 1 shared_exp = (shared_exp - emax).clamp(min=-scale_emax, max=scale_emax) scale = torch.pow(2.0, shared_exp) @@ -522,443 +511,55 @@ def get_imatrix_hook(module, input, output): return hook_handles -# ---------------------------- gguf alg ---------------------------- -from auto_round.data_type.gguf import double_quant_tensor -from auto_round.export.export_to_gguf.packing import make_qx_quants - - -def make_qp_quants(nmax, data, quant_weights, v=0): - data = data.to(torch.float32) - quant_weights = quant_weights.to(torch.float32) - group_max = torch.max(data, dim=-1, keepdim=True)[0] - scale = group_max / nmax - iscale = get_reciprocal(scale) - if isinstance(v, torch.Tensor) and v.numel() != 1: - v = v.view(data.shape) - v = v.to(data.device) - - L = torch.round(iscale * data + v) - diffs = data - scale * L - best_mse = torch.sum(quant_weights * diffs * diffs, dim=-1) - - for _is in range(-9, 10): - if _is == 0: - continue - scale_is = group_max / (0.1 * _is + nmax) - iscale_is = get_reciprocal(scale_is) - - tmp_L = torch.round(iscale_is * data + v).clip(max=nmax) - diffs = data - scale_is * tmp_L - mse = torch.sum(quant_weights * diffs * diffs, dim=-1) - - replace_idx = mse < best_mse - best_mse[replace_idx] = mse[replace_idx] - iscale[replace_idx] = iscale_is[replace_idx] - - L = torch.round(iscale * data + v).clip(max=nmax) - sumlx = torch.sum(quant_weights * data * L, dim=-1) - suml2 = torch.sum(quant_weights * L * L, dim=-1) - # When suml2 is zero (all L=0 or all quant_weights=0), fall back to the - # simple max-based scale estimate to avoid NaN propagating into the GGUF file. - fallback_d = group_max.squeeze(-1) / nmax - return torch.where(suml2 > 0, sumlx / suml2, fallback_d), L +def _dq_asym_qdq(tensor, scale, wmin, bits, group_size, v=0): + """Pure asym double-quant qdq math given precomputed scale/wmin. - -# @torch._disable_dynamo() -def iterative_wls_quant_search(data, bits=4, rrmin=-1.0, rdelta=0.1, nstep=20, use_mad=False, weights=None, v=0): - """Adapted from Llamacpp. Performs iterative weighted least squares quantization search. - - Args: - data (torch.Tensor): Input tensor to quantize. - bits (int): Number of quantization bits. - rrmin (float): Initial range scaling factor. - rdelta (float): Step size for range scaling. - nstep (int): Number of search steps. - use_mad (bool): Whether to use mean absolute deviation instead of squared error. - weights (torch.Tensor): Weight matrix for each element. - - Returns: - Tuple: (Optimal scale tensor, optimal minimum value tensor) + ``maxq`` is computed via ``int(2.0 ** bits) - 1`` so that any SymInt + handling for ``bits`` does not produce a ``libdevice.pow(fp32, i64)`` call + in Triton (which lacks that overload). The final ``int(...)`` cast keeps + ``maxq`` as a Python int constant inside the compiled graph. """ - dtype = torch.float32 - data = data.to(dtype) - maxq = 2**bits - 1 - minq = 0 - weights = 1.0 if weights is None else weights.to(dtype) - - rmin = torch.min(data, dim=1, keepdim=True)[0] - rmax = torch.max(data, dim=1, keepdim=True)[0] - - sum_w = torch.sum(weights, dim=1, keepdim=True) - sum_x = torch.sum(weights * data, dim=1, keepdim=True) - - # scale = 1 / ((maxq - minq) / (rmax - rmin + 1e-8)) - scale = (rmax - rmin) / (maxq - minq) - if isinstance(v, torch.Tensor) and v.numel() > 1: - v = v.reshape(data.shape) - - iscale = get_reciprocal(scale) - # quant_data = torch.clamp(torch.round((maxq - minq) / (rmax - rmin + 1e-8) * (data - rmin)), minq, maxq) - quant_data = torch.clamp(torch.round(iscale * (data - rmin) + v), minq, maxq) - diff = scale * quant_data + rmin - data - - best_mad = torch.sum((weights * torch.abs(diff)) if use_mad else weights * diff**2, dim=1, keepdim=True) - - for is_ in range(nstep): - factor = rrmin + rdelta * is_ + maxq - minq - # iscale_new = factor / (rmax - rmin + 1e-8) - scale_new = (rmax - rmin) / factor - iscale_new = get_reciprocal(scale_new) - quant_data_new = torch.clamp(torch.round(iscale_new * (data - rmin) + v), minq, maxq) - - mul_weights_quant_data = weights * quant_data_new - sum_l = torch.sum(mul_weights_quant_data, dim=-1, keepdim=True) - sum_l2 = torch.sum(mul_weights_quant_data * quant_data_new, dim=-1, keepdim=True) - sum_xl = torch.sum(mul_weights_quant_data * data, dim=-1, keepdim=True) - - D = sum_w * sum_l2 - sum_l**2 - this_scale = (sum_w * sum_xl - sum_x * sum_l) / D - this_min = (sum_l2 * sum_x - sum_l * sum_xl) / D - this_min[this_min > 0] = 0 - this_scale[this_min > 0] = (sum_xl / sum_l2)[this_min > 0] - reverse_this_scale = get_reciprocal(this_scale) - - quant_data = torch.clamp(torch.round(reverse_this_scale * (data - this_min) + v), minq, maxq) - diff = this_scale * quant_data + this_min - data - # diff = this_scale * quant_data_new + this_min - data - mad = torch.sum((weights * torch.abs(diff)) if use_mad else weights * diff**2, dim=-1, keepdim=True) - - idx_to_replace = torch.where((mad < best_mad) & (D > 0))[0] - best_mad[idx_to_replace] = mad[idx_to_replace] - scale[idx_to_replace] = this_scale[idx_to_replace] - rmin[idx_to_replace] = this_min[idx_to_replace] - - return scale.to(torch.float32), -rmin.to(torch.float32) - - -def make_qp_new_quants(data, orig_scale, orig_mins, quant_weights, bits=4, super_bits=6, data_v=0, scale_v=0, min_v=0): - nmax = 2**super_bits - 1 - maxq = 2**bits - 1 - minq = 0 - orig_scale = orig_scale.to(torch.float32) - quant_weights = quant_weights.to(torch.float32) - group_max = torch.max(orig_scale, dim=-1, keepdim=True)[0] - s_scale = group_max / nmax - i_sscale = get_reciprocal(s_scale) - if isinstance(scale_v, torch.Tensor): - if scale_v.numel() != 1: - scale_v = scale_v.view(orig_scale.shape) - scale_v = scale_v.to(data.device) - data_v = data_v.to(data.device) - - L_scale = torch.round(i_sscale * orig_scale + scale_v) - qdq_scale = L_scale * s_scale - id_scale = get_reciprocal(qdq_scale) - id_scale = id_scale.view(-1, 1) - orig_mins = orig_mins.view(-1, 1) if orig_mins is not None and isinstance(orig_mins, torch.Tensor) else orig_mins - quant_data = torch.clamp(torch.round(id_scale * (data - orig_mins) + data_v.to(data.device)), minq, maxq) - qdq_scale = qdq_scale.view(-1, 1) - diff = qdq_scale * quant_data + orig_mins - data - best_mse = torch.sum(quant_weights * diff * diff, dim=-1) - best_mse = best_mse.view(orig_scale.shape) - best_mse = torch.sum(best_mse, dim=-1) - for _is in range(-9, 10): - if _is == 0: - continue - scale_s_is = group_max / (0.1 * _is + nmax) - iscale_s_is = get_reciprocal(scale_s_is) - - tmp_L_scale = torch.round(iscale_s_is * orig_scale + scale_v).clip(min=0, max=nmax) - qdq_scale = scale_s_is * tmp_L_scale - reverse_this_scale = get_reciprocal(qdq_scale) - reverse_this_scale = reverse_this_scale.view(-1, 1) - quant_data = torch.clamp(torch.round(reverse_this_scale * (data - orig_mins) + data_v), minq, maxq) - diffs = qdq_scale.view(-1, 1) * quant_data + orig_mins - data - mse = torch.sum(quant_weights * diffs * diffs, dim=-1) - mse = mse.view(orig_scale.shape) - mse = torch.sum(mse, dim=-1) - replace_idx = mse < best_mse - best_mse[replace_idx] = mse[replace_idx] - i_sscale[replace_idx] = iscale_s_is[replace_idx] - - L = torch.round(i_sscale * orig_scale + scale_v).clip(max=nmax) - quant_weights = torch.sum(quant_weights, dim=-1) - quant_weights = quant_weights.view(orig_scale.shape) - sumlx = torch.sum(quant_weights * orig_scale * L, dim=-1) - suml2 = torch.sum(quant_weights * L * L, dim=-1) - # When suml2 is zero, fall back to the simple max-based scale estimate - # to avoid NaN propagating into the GGUF file. - fallback_d = group_max.squeeze(-1) / nmax - return torch.where(suml2 > 0, sumlx / suml2, fallback_d), L - - -def quant_tensor_gguf_asym_dq( - tensor, - bits=4, - v=0, - min_scale=1.0, - max_scale=1.0, - scale_dtype=torch.float16, - tensor_min=None, - tensor_max=None, - q_scale_thresh=1e-5, - imatrix=None, - prev_scale=None, - prev_wmin=None, - prev_d_scale=None, - prev_d_wmin=None, - iter=0, - scale_v=0, - wmin_v=0, - **kwargs, -): - """Quantizes and dequantizes a tensor using asymmetric integer quantization for formats like Q2_K, Q4_K, and Q5_K. - Only fit for iters 0 - - Args: - tensor (torch.Tensor): Input tensor to quantize. - bits (int): Number of bits for quantization. - group_size (int): Group size for per-group quantization. - v (float): Perturbation added before rounding. - min_scale (float): Minimum allowed scale value. - max_scale (float): Maximum allowed scale value. - scale_dtype (torch.dtype): Data type for quantized scale. - tensor_min (torch.Tensor, optional): Minimum values for the tensor groups. - tensor_max (torch.Tensor, optional): Maximum values for the tensor groups. - q_scale_thresh (float): Threshold to clamp the quantized scale. - super_group_size (int): Number of groups to bundle for secondary quantization. - super_bits (int): Number of bits used in secondary quantization. - imatrix (torch.Tensor, optional): Importance matrix for weighted quantization. - - Returns: - Tuple: (Quantized-dequantized tensor, scale dictionary, zero-point dictionary) - """ - - orig_dtype = tensor.dtype - maxq = 2**bits - 1 - group_size = 16 if bits == 2 else 32 - super_bits = 4 if bits == 2 else 6 - super_group_size = 16 if bits == 2 else 8 tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size) + orig_dtype = tensor.dtype tensor = tensor.to(torch.float32) - if iter is None: - iter = 0 - if iter % 10 == 0 or iter == -1 or prev_scale is None: - if bits not in [2, 4, 5]: - raise ValueError(f"bits={bits} not supported by rtn_int_asym_dq") - quant_weights = None - if imatrix is None or (imatrix is not None and torch.sum(imatrix) == 0): - search_kwargs = { - 2: {"rmin": -0.5, "rdelta": 0.1, "nstep": 15, "use_mad": True}, - 4: {"rmin": -1, "rdelta": 0.1, "nstep": 20, "use_mad": False}, - 5: {"rmin": -0.5, "rdelta": 0.1, "nstep": 15, "use_mad": False}, - } - if bits == 2: - quant_weights = torch.abs(tensor) - elif bits == 4 or bits == 5: - sigma2 = torch.sum(tensor**2, dim=-1, keepdim=True) / 32 ##Note 32 is different from QK_K - av_x = torch.sqrt(sigma2) - quant_weights = torch.abs(tensor) + av_x - params = search_kwargs[bits] - scale, wmin = iterative_wls_quant_search( - tensor, - bits=bits, - rrmin=params["rmin"], - rdelta=params["rdelta"], - nstep=params["nstep"], - use_mad=params["use_mad"], - weights=quant_weights, - v=v, - ) - scale = scale.to(scale_dtype) - scale = torch.where(torch.abs(scale) < 1e-30, torch.zeros_like(scale), scale) - scale = scale.reshape(-1, super_group_size) - wmin = wmin.reshape(-1, super_group_size) - scale, d_scale = double_quant_tensor(scale, super_bits) - wmin = torch.where(torch.abs(wmin) < 1e-30, torch.zeros_like(wmin), wmin) - wmin, d_wmin = double_quant_tensor(wmin, super_bits) - wmin = wmin.view(-1, 1) - scale = scale.view(-1, 1) - else: - imatrix = imatrix.to(tensor.device) - search_kwargs = { - 2: {"rmin": -0.9, "rdelta": 0.05, "nstep": 36, "use_mad": False}, - 4: {"rmin": -0.9, "rdelta": 0.05, "nstep": 36, "use_mad": False}, - 5: {"rmin": -0.9, "rdelta": 0.05, "nstep": 36, "use_mad": False}, - } - - weights = imatrix.reshape(1, -1) - - weights = weights.expand(tensor.numel() // weights.numel(), -1) - quant_weights = weights.reshape(tensor.shape) - - if torch.min(quant_weights) == 0: - logger.warning_once( - "please use more data via setting `nsamples` " - "to improve accuracy as calibration activations contain 0" - ) - - zero_cnt = torch.sum(quant_weights == 0, dim=-1) - replace_index = zero_cnt > group_size // 2 - if torch.sum(replace_index) > 0: - # Fallback to no imatrix - if bits == 2: - tmp_quant_weights = torch.abs(tensor) - elif bits == 4 or bits == 5: - sigma2 = torch.sum(tensor**2, dim=-1, keepdim=True) / 32 ## Note 32 is different from QK_K - av_x = torch.sqrt(sigma2) - tmp_quant_weights = torch.abs(tensor) + av_x - quant_weights[replace_index, :] = tmp_quant_weights[replace_index, :] - mean_replace_index = (zero_cnt > 0) & (zero_cnt <= group_size // 2) - if torch.sum(mean_replace_index) > 0: - ## use mean values to fill zero values - tmp_quant_weights = torch.sum(quant_weights, dim=-1) / (quant_weights.shape[1] - zero_cnt) - tmp_quant_weights = tmp_quant_weights.view(-1, 1).expand(-1, quant_weights.shape[1]) - quant_weights[mean_replace_index, :] = tmp_quant_weights[mean_replace_index, :] - - params = search_kwargs[bits] - - scale, wmin_0 = iterative_wls_quant_search( - tensor, - bits=bits, - rrmin=params["rmin"], - rdelta=params["rdelta"], - nstep=params["nstep"], - use_mad=params["use_mad"], - weights=quant_weights, - v=v, - ) - scale = scale.to(scale_dtype) - scale = torch.where(torch.abs(scale) < 1e-30, torch.zeros_like(scale), scale) - nmax = 2**super_bits - 1 - scale = scale.reshape(-1, super_group_size) - wmin = wmin_0.reshape(-1, super_group_size) - sum_quant_weights = quant_weights.sum(-1, keepdim=True).reshape(-1, super_group_size) - - d_scale, q_scale = make_qp_new_quants(tensor, scale, wmin, quant_weights, bits, super_bits, data_v=v) - d_scale = d_scale.unsqueeze(-1) - - d_wmin, q_wmin = make_qp_quants(nmax, wmin, sum_quant_weights, v=wmin_v) - - d_wmin = d_wmin.unsqueeze(-1) - scale = (d_scale * q_scale).view(-1, 1) - wmin = (d_wmin * q_wmin).view(-1, 1) - else: - scale = prev_scale.detach() - d_scale = prev_d_scale.detach() - wmin = prev_wmin.detach() - d_wmin = prev_d_wmin.detach() + maxq = int(2.0 ** bits) - 1 inverse_scale = get_reciprocal(scale) - int_w = torch.clamp(round_ste((tensor + wmin) * inverse_scale + v), 0, maxq) - qdq_result = (scale * int_w - wmin).to(orig_dtype) - qdq_result = revert_tensor_by_pad(qdq_result, orig_shape=orig_shape, pad_len=pad_len) - return qdq_result, {"scale": scale, "d_scale": d_scale}, {"wmin": wmin, "d_wmin": d_wmin} + qdq = (scale * int_w - wmin).to(orig_dtype) + qdq = revert_tensor_by_pad(qdq, orig_shape=orig_shape, pad_len=pad_len) + return qdq -def quant_tensor_gguf_sym_dq( - tensor, - bits=3, - v=0, - imatrix=None, - prev_scale=None, - prev_d_scale=None, - iter=0, - **kwargs, -): - """Quantize and de-quantize tensor asymmetrically. For Q3_K, Q6_K. +def _dq_sym_qdq(tensor, scale, bits, v=0): + """Pure sym double-quant qdq math given precomputed scale. - Args: - tensor: Tensor containing the tensor to be quantized - bits: Number of bits for quantization (e.g., 2, 3, 4, 8) - group_size: Number of elements to share scale for quantization - v: Rounding value perturbation - min_scale: Minimum scale coefficient for tensor - max_scale: Maximum scale coefficient for tensor - tensor_min (Tensor, optional): Minimum tensor value for quantization. Defaults to None. - tensor_max (Tensor, optional): Maximum tensor value for quantization. Defaults to None. - scale_dtype: dtype of the quantized scale,as most kernels only support FP16 or FP32, while this value is import - q_scale_thresh: clip the quantized scale's magnitude to this value to improve the numerical stability - - Returns: - Quantized and de-quantized tensor, scale, zero-point + ``maxq`` is computed via float ``2.0 ** (bits - 1)`` then cast to + ``int`` to avoid SymInt-driven shifts being lowered through + ``libdevice.pow``. """ - from auto_round.export.export_to_gguf.config import GGML_QUANT_SIZES, K_SCALE_SIZE, QK_K - - if bits not in [3, 6]: - raise KeyError(f"bits={bits} is not supported by gguf_int_sym_dq, please check.") + from auto_round.export.export_to_gguf.config import GGML_QUANT_SIZES, QK_K - maxq = 2 ** (bits - 1) group_size = 16 - super_bits = 6 if bits == 3 else 8 super_group_size = 16 - + maxq = int(2.0 ** (bits - 1)) tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size) - ggml_type = f"q{bits}_k" - block_size, type_size = GGML_QUANT_SIZES[ggml_type] orig_dtype = tensor.dtype - tensor = tensor.to(torch.float32) + ggml_type = f"q{bits}_k" + block_size, _ = GGML_QUANT_SIZES[ggml_type] n_blocks = tensor.nelement() // block_size - # (nb, 16, 16) - # tensor = tensor.reshape(n_blocks, super_group_size, QK_K // super_group_size) - - if iter is None: - iter = 0 - if iter % 10 == 0 or iter == -1 or prev_scale is None: - if imatrix is None or (imatrix is not None and torch.sum(imatrix) == 0): - if bits == 3: - from auto_round.export.export_to_gguf.packing import make_q3_quants - - scale, int_w = make_q3_quants(tensor, bits=bits, do_rmse=True) - elif bits == 6: - scale, int_w = make_qx_quants(tensor, bits=bits, rmse_type=1, qw=None) - else: - imatrix = imatrix.to(tensor.device) - - weights = imatrix.reshape(1, -1) - weights = weights.expand(tensor.numel() // weights.numel(), -1) - quant_weights = weights.reshape(tensor.shape) - if torch.min(quant_weights) == 0: - logger.warning_once( - "please use more data via setting `nsamples` " - "to improve accuracy as calibration activations contain 0" - ) - zero_cnt = torch.sum(quant_weights == 0, dim=-1) - replace_index = zero_cnt > group_size // 2 - if torch.sum(replace_index) > 0: - if bits == 6: - quant_weights[replace_index] = tensor[replace_index] * tensor[replace_index] - else: - sigma2 = 2 * torch.sum(torch.pow(tensor, 2), dim=-1, keepdim=True) / QK_K - tmp_quant_weights = torch.sqrt(sigma2 + tensor * tensor) - quant_weights[replace_index] = tmp_quant_weights[replace_index] - mean_replace_index = (zero_cnt > 0) & (zero_cnt <= group_size // 2) - if torch.sum(mean_replace_index) > 0: - ## use mean values to fill zero values - tmp_quant_weights = torch.sum(quant_weights, dim=-1) / (quant_weights.shape[-1] - zero_cnt) - tmp_quant_weights = ( - tmp_quant_weights.view(-1, 1).expand(-1, quant_weights.shape[1]).reshape(tensor.shape) - ) - quant_weights[mean_replace_index] = tmp_quant_weights[mean_replace_index] - - # scale, int_w = make_qx_quants(tensor, bits=bits, rmse_type=1, qw=quant_weights, v=v) - scale, int_w = make_qx_quants(tensor, bits=bits, rmse_type=1, qw=quant_weights) - scale = torch.where(torch.abs(scale) < 1e-30, torch.zeros_like(scale), scale) - # conduct double quant - d_scale, q_scale = make_qp_new_quants(tensor, scale, 0, quant_weights, bits, super_bits, data_v=v) - d_scale = d_scale.unsqueeze(-1) - scale = (d_scale * q_scale).unsqueeze(-1) + tensor = tensor.reshape(n_blocks, super_group_size, QK_K // super_group_size) + if isinstance(v, torch.Tensor): + v_r, _, _ = reshape_pad_tensor_by_group_size(v, group_size) + v_r = v_r.reshape(n_blocks, super_group_size, QK_K // super_group_size) else: - scale = prev_scale.detach() - d_scale = prev_d_scale.detach() - zp = torch.full_like(scale, maxq) # pylint: disable=E1130 + v_r = v + zp = torch.full_like(scale, maxq) inverse_scale = get_reciprocal(scale) - int_w = round_ste(tensor * inverse_scale + v).clip(-maxq, maxq - 1) + maxq - qdq_result = (scale * (int_w - zp)).to(orig_dtype) - qdq_result = revert_tensor_by_pad(qdq_result, orig_shape=orig_shape, pad_len=pad_len) - - return qdq_result, {"scale": scale, "d_scale": d_scale}, zp + int_w = round_ste(tensor * inverse_scale + v_r).clip(-maxq, maxq - 1) + maxq + qdq = (scale * (int_w - zp)).to(orig_dtype) + qdq = revert_tensor_by_pad(qdq, orig_shape=orig_shape, pad_len=pad_len) + return qdq class DQWrapperLinear(WrapperLinear): @@ -1016,10 +617,24 @@ def _init_tuning_params_and_quant_func(self): """ super()._init_tuning_params_and_quant_func() p_dtype = torch.float32 + # ``search_func`` is kept un-compiled because it contains data-dependent + # control flow (imatrix branches, iterative searches), while + # ``weight_quant_func`` is the compilable pure-math part. + self.search_func = None + self._dq_kind = None + self._is_dq_path = False if hasattr(self.orig_layer, "super_group_size") and self.orig_layer.super_group_size is not None: - self.weight_quant_func = ( - quant_tensor_gguf_asym_dq if self.orig_layer.data_type == "int_asym_dq" else quant_tensor_gguf_sym_dq - ) + self._is_dq_path = True + from auto_round.data_type.gguf import search_gguf_scale_min_asym, search_gguf_scale_min_sym + + if self.orig_layer.data_type == "int_asym_dq": + self.search_func = search_gguf_scale_min_asym + self.weight_quant_func = _dq_asym_qdq + self._dq_kind = "asym" + else: + self.search_func = search_gguf_scale_min_sym + self.weight_quant_func = _dq_sym_qdq + self._dq_kind = "sym" elif self.orig_layer.sym: from auto_round.data_type.int import quant_tensor_sym @@ -1030,10 +645,13 @@ def _init_tuning_params_and_quant_func(self): self.weight_quant_func = quant_tensor_asym self.data_type = self.orig_layer.data_type if self.enable_act_quant: + from auto_round.data_type.gguf import ( + quant_tensor_gguf_asym_dq as _gguf_asym_dq, + quant_tensor_gguf_sym_dq as _gguf_sym_dq, + ) + self.act_quant_func = ( - quant_tensor_gguf_asym_dq - if self.orig_layer.act_data_type == "int_asym_dq" - else quant_tensor_gguf_sym_dq + _gguf_asym_dq if self.orig_layer.act_data_type == "int_asym_dq" else _gguf_sym_dq ) if self.enable_torch_compile: self.act_quant_func = compile_func(self.act_quant_func, self.device) @@ -1042,7 +660,58 @@ def _init_tuning_params_and_quant_func(self): if self.enable_torch_compile: self.weight_quant_func = compile_func(self.weight_quant_func, self.device) - def _qdq_weight(self, value, min_scale, max_scale, scale_v=None, wmin_v=None, iter=None): + @torch.no_grad() + def _run_search(self, weight, v): + """Run the per-format scale/wmin search separately from the quant func. + + Uses the search routines from ``auto_round.data_type.gguf`` and forwards + the tuning perturbation ``v``. Returns the parameters to feed into the + (compilable) ``weight_quant_func``. + """ + from auto_round.data_type.gguf import double_quant_tensor_sym_rtn + from auto_round.export.export_to_gguf.config import GGML_QUANT_SIZES, QK_K + + bits = self.orig_layer.bits + scale_dtype = self.orig_layer.scale_dtype + imatrix = getattr(self.orig_layer, "imatrix", None) + + if self._dq_kind == "asym": + group_size = 16 if bits == 2 else 32 + t, _, _ = reshape_pad_tensor_by_group_size(weight.to(torch.float32), group_size) + v_r = v + if isinstance(v, torch.Tensor): + v_r, _, _ = reshape_pad_tensor_by_group_size(v, group_size) + scale, wmin, d_scale, d_wmin = self.search_func( + t, + bits=bits, + scale_dtype=scale_dtype, + imatrix=imatrix, + split_num=1, + v=v_r, + ) + return {"scale": scale, "wmin": wmin, "d_scale": d_scale, "d_wmin": d_wmin} + + # sym path + group_size = 16 + super_group_size = 16 + t, _, _ = reshape_pad_tensor_by_group_size(weight.to(torch.float32), group_size) + ggml_type = f"q{bits}_k" + block_size, _ = GGML_QUANT_SIZES[ggml_type] + n_blocks = t.nelement() // block_size + t = t.reshape(n_blocks, super_group_size, QK_K // super_group_size) + v_r = v + if isinstance(v, torch.Tensor): + v_r, _, _ = reshape_pad_tensor_by_group_size(v, group_size) + v_r = v_r.reshape(n_blocks, super_group_size, QK_K // super_group_size) + super_bits = 6 if bits == 3 else 8 + scale = self.search_func(t, bits, imatrix, scale_dtype, split_num=1, v=v_r) + scale = scale.to(scale_dtype) + scale = torch.where(torch.abs(scale) < 1e-30, torch.zeros_like(scale), scale) + scale, d_scale = double_quant_tensor_sym_rtn(scale, super_bits) + scale = scale.unsqueeze(-1) + return {"scale": scale, "d_scale": d_scale} + + def _qdq_weight(self, value, min_scale, max_scale, scale_v=None, iter=None): """Quantizes and dequantizes weights with tuning parameters. Args: @@ -1063,6 +732,57 @@ def _qdq_weight(self, value, min_scale, max_scale, scale_v=None, wmin_v=None, it if isinstance(self.orig_layer, transformers.pytorch_utils.Conv1D): weight = weight.t() + if self._is_dq_path: + # Split search (data-dependent, un-compiled) from quant math (compilable). + iter_v = self.cur_iter if (iter is None and hasattr(self, "cur_iter")) else iter + if iter_v is None: + iter_v = 0 + need_search = (iter_v % 10 == 0) or (iter_v == -1) or (self.prev_scale is None) + if need_search: + params = self._run_search(weight, value) + self.prev_scale = params["scale"] + self.prev_d_scale = params["d_scale"] + if self._dq_kind == "asym": + self.prev_wmin = params["wmin"] + self.prev_d_wmin = params["d_wmin"] + else: + params = { + "scale": self.prev_scale.detach(), + "d_scale": self.prev_d_scale.detach(), + } + if self._dq_kind == "asym": + params["wmin"] = self.prev_wmin.detach() + params["d_wmin"] = self.prev_d_wmin.detach() + + bits = self.orig_layer.bits + if self._dq_kind == "asym": + group_size = 16 if bits == 2 else 32 + weight_q = self.weight_quant_func( + weight, + params["scale"], + params["wmin"], + bits, + group_size, + v=value, + ) + scale_out = {"scale": params["scale"], "d_scale": params["d_scale"]} + zp_out = {"wmin": params["wmin"], "d_wmin": params["d_wmin"]} + else: + weight_q = self.weight_quant_func( + weight, + params["scale"], + bits, + v=value, + ) + scale_out = {"scale": params["scale"], "d_scale": params["d_scale"]} + zp_out = torch.full_like(params["scale"], int(2.0 ** (bits - 1))) + + weight_q = weight_q.to(weight.dtype) + if isinstance(self.orig_layer, transformers.pytorch_utils.Conv1D): + weight_q = weight_q.t() + return weight_q, scale_out, zp_out + + # Non-dq path: preserve original behavior. quant_kwargs = {} if hasattr(self.orig_layer, "super_bits"): quant_kwargs["super_bits"] = self.orig_layer.super_bits @@ -1083,11 +803,6 @@ def _qdq_weight(self, value, min_scale, max_scale, scale_v=None, wmin_v=None, it prev_wmin=self.prev_wmin, prev_d_scale=self.prev_d_scale, prev_d_wmin=self.prev_d_wmin, - imatrix=self.orig_layer.imatrix if hasattr(self.orig_layer, "imatrix") else None, - iter=self.cur_iter if (iter is None and hasattr(self, "cur_iter")) else iter, - # scale_v=self.scale_v if scale_v is None else scale_v, - # wmin_v=self.wmin_v if wmin_v is None else wmin_v, - # xtx=self.orig_layer.xtx if hasattr(self.orig_layer, "xtx") else None, **quant_kwargs, ) weight_q = weight_q.to(weight.dtype) diff --git a/auto_round/data_type/gguf.py b/auto_round/data_type/gguf.py index 1a419d880..453f6eb0e 100644 --- a/auto_round/data_type/gguf.py +++ b/auto_round/data_type/gguf.py @@ -59,7 +59,7 @@ def quant_tensor_sym_dq( """ tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size) - maxq = 2 ** (bits - 1) + maxq = int(2.0 ** (bits - 1)) if tensor_min is None or tensor_max is None: wmin_tmp = torch.clamp(tensor.min(-1)[0], max=0) wmax_tmp = torch.clamp(tensor.max(-1)[0], min=0) @@ -78,7 +78,7 @@ def quant_tensor_sym_dq( scale = scale.view(-1, 1) zp = torch.full_like(scale, maxq) # pylint: disable=E1130 int_w = round_ste(tensor * get_reciprocal(scale) + v) - q = torch.clamp(int_w + zp, 0, 2**bits - 1) + q = torch.clamp(int_w + zp, 0, int(2.0 ** bits) - 1) qdq_result = (scale * (q - zp)).to(tensor.dtype) qdq_result = revert_tensor_by_pad(qdq_result, orig_shape=orig_shape, pad_len=pad_len) return qdq_result, {"scale": scale, "d_scale": d_scale}, zp @@ -116,7 +116,7 @@ def quant_tensor_asym_float_zp( Quantized and de-quantized tensor, scale, zero-point """ tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size) - maxq = 2**bits - 1 + maxq = int(2.0 ** bits) - 1 if tensor_min is None or tensor_max is None: wmin_tmp = torch.clamp(tensor.min(-1)[0], max=0) wmax_tmp = torch.clamp(tensor.max(-1)[0], min=0) @@ -141,10 +141,69 @@ def quant_tensor_asym_float_zp( return qdq_result, scale, zp +@register_dtype("opt_rtn_int_asym_float_zp") +def quant_tensor_asym_float_zp_rtn( + tensor, + bits=4, + group_size=-1, + scale_dtype=torch.float16, + tensor_min=None, + tensor_max=None, + q_scale_thresh=1e-5, + **kwargs, +): + """RTN (round-to-nearest) version of asymmetric float-zp quantization. + + Optimized with in-place operations to minimize temporary allocations. + Does not support per-element perturbation ``v`` or ``min_scale``/``max_scale`` + tensor scaling (pure RTN). + + Args: + tensor: Tensor to quantize. + bits: Number of bits for quantization. + group_size: Number of elements sharing one scale. + scale_dtype: dtype of the quantized scale. + tensor_min/tensor_max: Optional pre-computed per-group min/max. + q_scale_thresh: Lower clip for scale magnitude. + + Returns: + Tuple of (qdq tensor, scale, zero-point). + """ + tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size) + orig_dtype = tensor.dtype + maxq = int(2.0 ** bits) - 1 + + if tensor_min is None or tensor_max is None: + wmin = torch.clamp_max(tensor.min(-1)[0], 0) + wmax = torch.clamp_min(tensor.max(-1)[0], 0) + else: + wmin = tensor_min + wmax = tensor_max + + # scale = (wmax - wmin) / maxq + scale = (wmax - wmin).div_(maxq).to(scale_dtype) + scale.clamp_(min=q_scale_thresh) + + # zp = -wmin / scale + zp = (-wmin).div_(scale) + + scale = scale.unsqueeze(-1) + zp = zp.unsqueeze(-1) + + # Cast to float32 for stable inplace math, then back to original dtype + if tensor.dtype != torch.float32: + tensor = tensor.float() + inverse_scale = get_reciprocal(scale) + tensor.mul_(inverse_scale).round_().add_(zp).clamp_(0, maxq).sub_(zp).mul_(scale) + tensor = tensor.to(orig_dtype) + tensor = revert_tensor_by_pad(tensor, orig_shape=orig_shape, pad_len=pad_len) + return tensor, scale, zp + + ## the values should be positive def double_quant_tensor(tensor, bits): tensor = tensor.to(torch.float32) # Ensure tensor is in float32 for precision - maxq = 2**bits - 1 + maxq = int(2.0 ** bits) - 1 wmax = torch.clamp(tensor.max(-1)[0], min=0) scale = wmax / maxq scale = scale.view(-1, 1) @@ -156,7 +215,7 @@ def double_quant_tensor(tensor, bits): def double_quant_tensor_sym(tensor, bits): tensor = tensor.to(torch.float32) # Ensure tensor is in float32 for precision - maxq = 2 ** (bits - 1) + maxq = int(2.0 ** (bits - 1)) imax = abs(tensor).argmax(axis=-1, keepdims=True) wmax = torch.take_along_dim(tensor, imax, dim=-1) scale = wmax / -maxq @@ -175,7 +234,7 @@ def double_quant_tensor_sym_rtn(tensor, bits): if tensor.dtype != torch.float32: tensor = tensor.float() # .float() creates a copy if needed - maxq = 2 ** (bits - 1) + maxq = int(2.0 ** (bits - 1)) # Compute absolute max along last dim # abs_() is inplace @@ -196,14 +255,14 @@ def double_quant_tensor_sym_rtn(tensor, bits): return tensor, scale -def make_qp_quants(nmax, data, quant_weights): +def make_qp_quants(nmax, data, quant_weights, v=0): data = data.to(torch.float32) quant_weights = quant_weights.to(torch.float32) group_max = torch.max(data, dim=-1, keepdim=True)[0] scale = group_max / nmax iscale = get_reciprocal(scale) - L = torch.round(iscale * data) + L = torch.round(iscale * data + v) diffs = data - scale * L best_mse = torch.sum(quant_weights * diffs * diffs, dim=-1) @@ -213,7 +272,7 @@ def make_qp_quants(nmax, data, quant_weights): scale_is = group_max / (0.1 * _is + nmax) iscale_is = get_reciprocal(scale_is) - tmp_L = torch.round(iscale_is * data).clip(max=nmax) + tmp_L = torch.round(iscale_is * data + v).clip(max=nmax) diffs = data - scale_is * tmp_L mse = torch.sum(quant_weights * diffs * diffs, dim=-1) @@ -221,7 +280,7 @@ def make_qp_quants(nmax, data, quant_weights): best_mse[replace_idx] = mse[replace_idx] iscale[replace_idx] = iscale_is[replace_idx] - L = torch.round(iscale * data).clip(max=nmax) + L = torch.round(iscale * data + v).clip(max=nmax) sumlx = torch.sum(quant_weights * data * L, dim=-1) suml2 = torch.sum(quant_weights * L * L, dim=-1) # @@ -285,7 +344,7 @@ def quant_tensor_asym_dq( Quantized and de-quantized tensor, scale, zero-point """ tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size) - maxq = 2**bits - 1 + maxq = int(2.0 ** bits) - 1 if tensor_min is None or tensor_max is None: wmin_tmp = torch.clamp(tensor.min(-1)[0], max=0) wmax_tmp = torch.clamp(tensor.max(-1)[0], min=0) @@ -304,7 +363,7 @@ def quant_tensor_asym_dq( wmin = -wmin # pylint: disable=E1130 wmin = wmin.view(-1, super_group_size) - ##conduct double quant + #Conduct double quant scale, d_scale = double_quant_tensor(scale, super_bits) wmin, d_wmin = double_quant_tensor(wmin, super_bits) @@ -356,7 +415,7 @@ def _imatrix_handle_zero(imatrix: Union[torch.Tensor, float], weight: torch.Tens @torch.inference_mode() -def search_gguf_scale_min_asym(tensor, bits=4, scale_dtype=torch.float16, imatrix=None, split_num=1): +def search_gguf_scale_min_asym(tensor, bits=4, scale_dtype=torch.float16, imatrix=None, split_num=1, v=0): super_bits = 4 if bits == 2 else 6 super_group_size = 16 if bits == 2 else 8 @@ -383,6 +442,7 @@ def search_gguf_scale_min_asym(tensor, bits=4, scale_dtype=torch.float16, imatri use_mad=params["use_mad"], weights=quant_weights, split_num=split_num, + v=v, ) scale = scale.to(scale_dtype) scale = torch.where(torch.abs(scale) < 1e-30, torch.zeros_like(scale), scale) @@ -426,10 +486,11 @@ def search_gguf_scale_min_asym(tensor, bits=4, scale_dtype=torch.float16, imatri nstep=params["nstep"], use_mad=params["use_mad"], weights=quant_weights, + v=v, ) scale = scale.to(scale_dtype) scale = torch.where(torch.abs(scale) < 1e-30, torch.zeros_like(scale), scale) - nmax = 2**super_bits - 1 + nmax = int(2.0 ** super_bits) - 1 scale = scale.reshape(-1, super_group_size) wmin = wmin_0.reshape(-1, super_group_size) sum_quant_weights = quant_weights.sum(-1, keepdim=True).reshape(-1, super_group_size) @@ -464,7 +525,6 @@ def quant_tensor_gguf_asym_dq( Args: tensor (torch.Tensor): Input tensor to quantize. bits (int): Number of bits for quantization. - v (float): Perturbation added before rounding. scale_dtype (torch.dtype): Data type for quantized scale. imatrix (torch.Tensor, optional): Importance matrix for weighted quantization. @@ -474,7 +534,7 @@ def quant_tensor_gguf_asym_dq( if bits not in [2, 4, 5]: raise ValueError(f"bits={bits} not supported by rtn_int_asym_dq") orig_dtype = tensor.dtype - maxq = 2**bits - 1 + maxq = int(2.0 ** bits) - 1 group_size = 16 if bits == 2 else 32 split_num = 1 @@ -497,13 +557,14 @@ def quant_tensor_gguf_asym_dq( # TODO consolidate iterative_wls_quant_search_chunk and non-chunk def iterative_wls_quant_search_chunk( - data, bits=4, rrmin=-1.0, rdelta=0.1, nstep=20, use_mad=False, weights=None, split_num=1 + data, bits=4, rrmin=-1.0, rdelta=0.1, nstep=20, use_mad=False, weights=None, split_num=1, v=0 ): dtype = torch.float32 data = data.to(dtype) - maxq = 2**bits - 1 + maxq = int(2.0 ** bits) - 1 minq = 0 weights = 1.0 if weights is None else weights.to(dtype) + v_is_tensor = isinstance(v, torch.Tensor) results_scale = [] results_rmin = [] @@ -514,6 +575,7 @@ def iterative_wls_quant_search_chunk( end = min(start + chunk_size, data.shape[0]) chunk = data[start:end] chunk_weights = weights if isinstance(weights, float) else weights[start:end] + v_chunk = v[start:end] if v_is_tensor else v # Pre-allocate reusable buffers to avoid new allocations tmp = torch.empty_like(chunk) @@ -528,8 +590,10 @@ def iterative_wls_quant_search_chunk( scale = (rmax - rmin) / (maxq - minq) iscale = get_reciprocal(scale) - # tmp = (chunk - rmin) * iscale + # tmp = (chunk - rmin) * iscale + v tmp.copy_(chunk).sub_(rmin).mul_(iscale) + if v_is_tensor or v != 0: + tmp.add_(v_chunk) # quant_data = round(tmp).clamp_() torch.round(tmp, out=quant_data) @@ -550,8 +614,10 @@ def iterative_wls_quant_search_chunk( scale_new = (rmax - rmin) / factor iscale_new = get_reciprocal(scale_new) - # tmp = (chunk - rmin) * iscale_new + # tmp = (chunk - rmin) * iscale_new + v tmp.copy_(chunk).sub_(rmin).mul_(iscale_new) + if v_is_tensor or v != 0: + tmp.add_(v_chunk) torch.round(tmp, out=quant_data) quant_data.clamp_(minq, maxq) @@ -575,8 +641,10 @@ def iterative_wls_quant_search_chunk( reverse_this_scale = get_reciprocal(this_scale) - # tmp = (chunk - this_min) * reverse_this_scale + # tmp = (chunk - this_min) * reverse_this_scale + v tmp.copy_(chunk).sub_(this_min).mul_(reverse_this_scale) + if v_is_tensor or v != 0: + tmp.add_(v_chunk) torch.round(tmp, out=quant_data) quant_data.clamp_(minq, maxq) @@ -608,7 +676,7 @@ def iterative_wls_quant_search_chunk( def iterative_wls_quant_search( - data, bits=4, rrmin=-1.0, rdelta=0.1, nstep=20, use_mad=False, weights=None, split_num=1 + data, bits=4, rrmin=-1.0, rdelta=0.1, nstep=20, use_mad=False, weights=None, split_num=1, v=0 ): """Adapted from Llamacpp. Performs iterative weighted least squares quantization search. @@ -620,6 +688,8 @@ def iterative_wls_quant_search( nstep (int): Number of search steps. use_mad (bool): Whether to use mean absolute deviation instead of squared error. weights (torch.Tensor): Weight matrix for each element. + v: Optional rounding perturbation (scalar or tensor with the same + shape as ``data``) for AutoRound-style tuning. Returns: Tuple: (Optimal scale tensor, optimal minimum value tensor) @@ -636,19 +706,22 @@ def iterative_wls_quant_search( use_mad=use_mad, weights=weights, split_num=split_num, + v=v, ) @torch.inference_mode() -def search_gguf_scale_min_sym(tensor, bits, imatrix, scale_dtype, split_num): +def search_gguf_scale_min_sym(tensor, bits, imatrix, scale_dtype, split_num, v=0): if imatrix is None or (imatrix is not None and torch.sum(imatrix) == 0): if bits == 3: # Note: make_q3_quants does not support split_num/chunking; # 3-bit quantization is performed in a single chunk. - scale, int_w = make_q3_quants(tensor, bits=bits, do_rmse=True) + scale, int_w = make_q3_quants(tensor, bits=bits, do_rmse=True, v=v) # scale, int_w = make_qx_quants(tensor, bits=bits, rmse_type=1, qw=None) elif bits == 6: - scale, int_w = make_qx_quants_chunk(tensor, bits=bits, rmse_type=1, qw=None, split_num=split_num) + scale, int_w = make_qx_quants_chunk( + tensor, bits=bits, rmse_type=1, qw=None, split_num=split_num, v=v + ) else: imatrix = imatrix.to(tensor.device) weights = imatrix.reshape(1, -1) @@ -657,7 +730,9 @@ def search_gguf_scale_min_sym(tensor, bits, imatrix, scale_dtype, split_num): quant_weights = _imatrix_handle_zero(quant_weights, tensor, bits) - scale, int_w = make_qx_quants_chunk(tensor, bits=bits, rmse_type=1, qw=quant_weights, split_num=split_num) + scale, int_w = make_qx_quants_chunk( + tensor, bits=bits, rmse_type=1, qw=quant_weights, split_num=split_num, v=v + ) if split_num > 1: clear_memory(device_list=[tensor.device]) return scale @@ -679,7 +754,6 @@ def quant_tensor_gguf_sym_dq( Args: tensor: Tensor containing the tensor to be quantized bits: Number of bits for quantization (e.g., 2, 3, 4, 8) - v: Rounding value perturbation min_scale: Minimum scale coefficient for tensor max_scale: Maximum scale coefficient for tensor tensor_min (Tensor, optional): Minimum tensor value for quantization. Defaults to None. @@ -696,7 +770,7 @@ def quant_tensor_gguf_sym_dq( if bits not in [3, 6]: raise KeyError(f"bits={bits} is not supported by gguf_int_sym_dq, please check.") - maxq = 2 ** (bits - 1) + maxq = int(2.0 ** (bits - 1)) group_size = 16 tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size) orig_dtype = tensor.dtype diff --git a/auto_round/data_type/int.py b/auto_round/data_type/int.py index 73645dd60..4a856a375 100644 --- a/auto_round/data_type/int.py +++ b/auto_round/data_type/int.py @@ -23,7 +23,7 @@ def search_scales(data: torch.Tensor, bits: int, qw: Union[None, torch.Tensor, float] = None) -> torch.Tensor: # Maximum absolute value for symmetric quantization - nmax = 1 << (bits - 1) # equivalent to pow(2, bits-1) + nmax = int(2.0 ** (bits - 1)) # Find per-group max along the last dimension imax = torch.abs(data).argmax(dim=-1, keepdim=True) @@ -103,7 +103,7 @@ def quant_tensor_opt_rtn_sym(tensor, bits=4, group_size=-1, v=0, q_scale_thresh= from auto_round.data_type.gguf import _imatrix_handle_zero tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size) - maxq = 2 ** (bits - 1) + maxq = int(2.0 ** (bits - 1)) if imatrix is None: imatrix = 1.0 else: @@ -146,7 +146,7 @@ def quant_tensor_rtn_sym( Quantized and de-quantized tensor, scale, zero-point """ tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size) - maxq = 2 ** (bits - 1) + maxq = int(2.0 ** (bits - 1)) wmin_tmp = torch.clamp(tensor.min(-1)[0], max=0) wmax_tmp = torch.clamp(tensor.max(-1)[0], min=0) @@ -195,7 +195,7 @@ def quant_tensor_sym( """ tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size) - maxq = 2 ** (bits - 1) + maxq = int(2.0 ** (bits - 1)) if tensor_min is None or tensor_max is None: wmin_tmp = torch.clamp(tensor.min(-1)[0], max=0) wmax_tmp = torch.clamp(tensor.max(-1)[0], min=0) @@ -248,7 +248,7 @@ def quant_tensor_asym( Quantized and de-quantized tensor, scale, zero-point """ tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size) - maxq = 2**bits - 1 + maxq = int(2.0 ** bits) - 1 if tensor_min is None or tensor_max is None: wmin_tmp = torch.clamp(tensor.min(-1)[0], max=0) wmax_tmp = torch.clamp(tensor.max(-1)[0], min=0) @@ -305,7 +305,7 @@ def quant_tensor_sym_gptq( Quantized and de-quantized tensor, scale, zero-point """ tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size) - maxq = 2**bits - 1 + maxq = int(2.0 ** bits) - 1 if tensor_min is None or tensor_max is None: wmin_tmp = torch.clamp(tensor.min(-1)[0], max=0) wmax_tmp = torch.clamp(tensor.max(-1)[0], min=0) @@ -368,7 +368,7 @@ def quant_tensor_asym_wo_round( Quantized and de-quantize tensor, scale, zero-point """ tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size) - maxq = 2**bits - 1 + maxq = int(2.0 ** bits) - 1 if tensor_min is None or tensor_max is None: wmin_tmp = torch.clamp(tensor.min(-1)[0], max=0) wmax_tmp = torch.clamp(tensor.max(-1)[0], min=0) diff --git a/auto_round/export/export_to_gguf/packing.py b/auto_round/export/export_to_gguf/packing.py index 92b13d045..82d5ac50c 100644 --- a/auto_round/export/export_to_gguf/packing.py +++ b/auto_round/export/export_to_gguf/packing.py @@ -130,21 +130,28 @@ def torch_roundf(n): return torch.sign(n) * b -def make_qx_quants_chunk(data, bits, rmse_type=0, qw=None, split_num=1): +def make_qx_quants_chunk(data, bits, rmse_type=0, qw=None, split_num=1, v=0): """ Extreme VRAM-optimized version of quantization. - Processes data in chunks along the batch dimension (dim=0) to reduce peak memory usage. - Uses inplace operations to avoid unnecessary tensor copies. - Reuses buffers for temporary calculations wherever possible. + + Args: + v: Optional rounding perturbation. Either a scalar or a tensor with the + same shape as ``data``; added before each ``round`` to support + AutoRound-style tuning. """ nmax = 2 ** (bits - 1) scales_list = [] L_list = [] chunk_size = (data.shape[0] + split_num - 1) // split_num + v_is_tensor = isinstance(v, torch.Tensor) for start in range(0, data.shape[0], chunk_size): end = min(start + chunk_size, data.shape[0]) chunk = data[start:end] # Slice a batch chunk to reduce memory footprint + v_chunk = v[start:end] if v_is_tensor else v # Compute absolute values inplace to avoid extra tensor allocation chunk_abs = chunk.abs() @@ -156,7 +163,7 @@ def make_qx_quants_chunk(data, bits, rmse_type=0, qw=None, split_num=1): iscales = -nmax * get_reciprocal(group_max) # L buffer stores quantized values, modified inplace to save memory - L = (chunk * iscales).round_().clamp_(-nmax, nmax - 1) + L = (chunk * iscales + v_chunk).round_().clamp_(-nmax, nmax - 1) # Simple case: rmse_type == 0 if rmse_type == 0: @@ -204,7 +211,7 @@ def make_qx_quants_chunk(data, bits, rmse_type=0, qw=None, split_num=1): continue iscales_tmp = -(nmax + -0.1 * _is) / group_max # Use a temporary L buffer to avoid creating new large tensor - L_tmp = (chunk * iscales_tmp).round_().clamp_(-nmax, nmax - 1) + L_tmp = (chunk * iscales_tmp + v_chunk).round_().clamp_(-nmax, nmax - 1) sumlx_tmp = (w * chunk * L_tmp).sum(dim=-1) suml2_tmp = (w * L_tmp * L_tmp).sum(dim=-1) # Determine which elements should be replaced @@ -283,7 +290,7 @@ def make_qx_quants(data, bits, rmse_type=0, qw=None): return scales, L -def make_q3_quants(data, bits, do_rmse=False): +def make_q3_quants(data, bits, do_rmse=False, v=0): # Maximum absolute integer value for symmetric quantization nmax = 1 << (bits - 1) # equivalent to pow(2, bits-1) @@ -299,7 +306,7 @@ def make_q3_quants(data, bits, do_rmse=False): if do_rmse: # Initial quantization L (in-place round and clamp) L = torch.empty_like(data) - torch.round(iscale * data, out=L) + torch.round(iscale * data + v, out=L) L.clamp_(-nmax, nmax - 1) # Weight for RMSE = x^2 (in-place) @@ -349,7 +356,7 @@ def make_q3_quants(data, bits, do_rmse=False): # Fast path: quantize without RMSE (in-place round, clamp, shift) L = torch.empty_like(data) - torch.round(iscale * data, out=L) + torch.round(iscale * data + v, out=L) L.clamp_(-nmax, nmax - 1) L.add_(nmax) From b7041b7bb577e575156cb0bd7e503d8d10bd97a2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 9 May 2026 07:46:56 +0000 Subject: [PATCH 04/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- auto_round/alg_ext.py | 12 ++++-------- auto_round/data_type/gguf.py | 26 +++++++++++--------------- auto_round/data_type/int.py | 6 +++--- 3 files changed, 18 insertions(+), 26 deletions(-) diff --git a/auto_round/alg_ext.py b/auto_round/alg_ext.py index c8c04e5e0..e8688a234 100644 --- a/auto_round/alg_ext.py +++ b/auto_round/alg_ext.py @@ -522,7 +522,7 @@ def _dq_asym_qdq(tensor, scale, wmin, bits, group_size, v=0): tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size) orig_dtype = tensor.dtype tensor = tensor.to(torch.float32) - maxq = int(2.0 ** bits) - 1 + maxq = int(2.0**bits) - 1 inverse_scale = get_reciprocal(scale) int_w = torch.clamp(round_ste((tensor + wmin) * inverse_scale + v), 0, maxq) qdq = (scale * int_w - wmin).to(orig_dtype) @@ -645,14 +645,10 @@ def _init_tuning_params_and_quant_func(self): self.weight_quant_func = quant_tensor_asym self.data_type = self.orig_layer.data_type if self.enable_act_quant: - from auto_round.data_type.gguf import ( - quant_tensor_gguf_asym_dq as _gguf_asym_dq, - quant_tensor_gguf_sym_dq as _gguf_sym_dq, - ) + from auto_round.data_type.gguf import quant_tensor_gguf_asym_dq as _gguf_asym_dq + from auto_round.data_type.gguf import quant_tensor_gguf_sym_dq as _gguf_sym_dq - self.act_quant_func = ( - _gguf_asym_dq if self.orig_layer.act_data_type == "int_asym_dq" else _gguf_sym_dq - ) + self.act_quant_func = _gguf_asym_dq if self.orig_layer.act_data_type == "int_asym_dq" else _gguf_sym_dq if self.enable_torch_compile: self.act_quant_func = compile_func(self.act_quant_func, self.device) self._init_params("act_max_scale", p_dtype, (1), 1.0, not self.orig_layer.act_dynamic) diff --git a/auto_round/data_type/gguf.py b/auto_round/data_type/gguf.py index 453f6eb0e..a0b2e63fc 100644 --- a/auto_round/data_type/gguf.py +++ b/auto_round/data_type/gguf.py @@ -78,7 +78,7 @@ def quant_tensor_sym_dq( scale = scale.view(-1, 1) zp = torch.full_like(scale, maxq) # pylint: disable=E1130 int_w = round_ste(tensor * get_reciprocal(scale) + v) - q = torch.clamp(int_w + zp, 0, int(2.0 ** bits) - 1) + q = torch.clamp(int_w + zp, 0, int(2.0**bits) - 1) qdq_result = (scale * (q - zp)).to(tensor.dtype) qdq_result = revert_tensor_by_pad(qdq_result, orig_shape=orig_shape, pad_len=pad_len) return qdq_result, {"scale": scale, "d_scale": d_scale}, zp @@ -116,7 +116,7 @@ def quant_tensor_asym_float_zp( Quantized and de-quantized tensor, scale, zero-point """ tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size) - maxq = int(2.0 ** bits) - 1 + maxq = int(2.0**bits) - 1 if tensor_min is None or tensor_max is None: wmin_tmp = torch.clamp(tensor.min(-1)[0], max=0) wmax_tmp = torch.clamp(tensor.max(-1)[0], min=0) @@ -171,7 +171,7 @@ def quant_tensor_asym_float_zp_rtn( """ tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size) orig_dtype = tensor.dtype - maxq = int(2.0 ** bits) - 1 + maxq = int(2.0**bits) - 1 if tensor_min is None or tensor_max is None: wmin = torch.clamp_max(tensor.min(-1)[0], 0) @@ -203,7 +203,7 @@ def quant_tensor_asym_float_zp_rtn( ## the values should be positive def double_quant_tensor(tensor, bits): tensor = tensor.to(torch.float32) # Ensure tensor is in float32 for precision - maxq = int(2.0 ** bits) - 1 + maxq = int(2.0**bits) - 1 wmax = torch.clamp(tensor.max(-1)[0], min=0) scale = wmax / maxq scale = scale.view(-1, 1) @@ -344,7 +344,7 @@ def quant_tensor_asym_dq( Quantized and de-quantized tensor, scale, zero-point """ tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size) - maxq = int(2.0 ** bits) - 1 + maxq = int(2.0**bits) - 1 if tensor_min is None or tensor_max is None: wmin_tmp = torch.clamp(tensor.min(-1)[0], max=0) wmax_tmp = torch.clamp(tensor.max(-1)[0], min=0) @@ -363,7 +363,7 @@ def quant_tensor_asym_dq( wmin = -wmin # pylint: disable=E1130 wmin = wmin.view(-1, super_group_size) - #Conduct double quant + # Conduct double quant scale, d_scale = double_quant_tensor(scale, super_bits) wmin, d_wmin = double_quant_tensor(wmin, super_bits) @@ -490,7 +490,7 @@ def search_gguf_scale_min_asym(tensor, bits=4, scale_dtype=torch.float16, imatri ) scale = scale.to(scale_dtype) scale = torch.where(torch.abs(scale) < 1e-30, torch.zeros_like(scale), scale) - nmax = int(2.0 ** super_bits) - 1 + nmax = int(2.0**super_bits) - 1 scale = scale.reshape(-1, super_group_size) wmin = wmin_0.reshape(-1, super_group_size) sum_quant_weights = quant_weights.sum(-1, keepdim=True).reshape(-1, super_group_size) @@ -534,7 +534,7 @@ def quant_tensor_gguf_asym_dq( if bits not in [2, 4, 5]: raise ValueError(f"bits={bits} not supported by rtn_int_asym_dq") orig_dtype = tensor.dtype - maxq = int(2.0 ** bits) - 1 + maxq = int(2.0**bits) - 1 group_size = 16 if bits == 2 else 32 split_num = 1 @@ -561,7 +561,7 @@ def iterative_wls_quant_search_chunk( ): dtype = torch.float32 data = data.to(dtype) - maxq = int(2.0 ** bits) - 1 + maxq = int(2.0**bits) - 1 minq = 0 weights = 1.0 if weights is None else weights.to(dtype) v_is_tensor = isinstance(v, torch.Tensor) @@ -719,9 +719,7 @@ def search_gguf_scale_min_sym(tensor, bits, imatrix, scale_dtype, split_num, v=0 scale, int_w = make_q3_quants(tensor, bits=bits, do_rmse=True, v=v) # scale, int_w = make_qx_quants(tensor, bits=bits, rmse_type=1, qw=None) elif bits == 6: - scale, int_w = make_qx_quants_chunk( - tensor, bits=bits, rmse_type=1, qw=None, split_num=split_num, v=v - ) + scale, int_w = make_qx_quants_chunk(tensor, bits=bits, rmse_type=1, qw=None, split_num=split_num, v=v) else: imatrix = imatrix.to(tensor.device) weights = imatrix.reshape(1, -1) @@ -730,9 +728,7 @@ def search_gguf_scale_min_sym(tensor, bits, imatrix, scale_dtype, split_num, v=0 quant_weights = _imatrix_handle_zero(quant_weights, tensor, bits) - scale, int_w = make_qx_quants_chunk( - tensor, bits=bits, rmse_type=1, qw=quant_weights, split_num=split_num, v=v - ) + scale, int_w = make_qx_quants_chunk(tensor, bits=bits, rmse_type=1, qw=quant_weights, split_num=split_num, v=v) if split_num > 1: clear_memory(device_list=[tensor.device]) return scale diff --git a/auto_round/data_type/int.py b/auto_round/data_type/int.py index 4a856a375..843ff2351 100644 --- a/auto_round/data_type/int.py +++ b/auto_round/data_type/int.py @@ -248,7 +248,7 @@ def quant_tensor_asym( Quantized and de-quantized tensor, scale, zero-point """ tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size) - maxq = int(2.0 ** bits) - 1 + maxq = int(2.0**bits) - 1 if tensor_min is None or tensor_max is None: wmin_tmp = torch.clamp(tensor.min(-1)[0], max=0) wmax_tmp = torch.clamp(tensor.max(-1)[0], min=0) @@ -305,7 +305,7 @@ def quant_tensor_sym_gptq( Quantized and de-quantized tensor, scale, zero-point """ tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size) - maxq = int(2.0 ** bits) - 1 + maxq = int(2.0**bits) - 1 if tensor_min is None or tensor_max is None: wmin_tmp = torch.clamp(tensor.min(-1)[0], max=0) wmax_tmp = torch.clamp(tensor.max(-1)[0], min=0) @@ -368,7 +368,7 @@ def quant_tensor_asym_wo_round( Quantized and de-quantize tensor, scale, zero-point """ tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size) - maxq = int(2.0 ** bits) - 1 + maxq = int(2.0**bits) - 1 if tensor_min is None or tensor_max is None: wmin_tmp = torch.clamp(tensor.min(-1)[0], max=0) wmax_tmp = torch.clamp(tensor.max(-1)[0], min=0) From e5defb1410c8a22dd0b0ef59dc3ad8c3c3af217b Mon Sep 17 00:00:00 2001 From: Wenhua Cheng Date: Sat, 9 May 2026 16:56:41 +0800 Subject: [PATCH 05/13] fix --- auto_round/alg_ext.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/auto_round/alg_ext.py b/auto_round/alg_ext.py index c8c04e5e0..f771a5fbc 100644 --- a/auto_round/alg_ext.py +++ b/auto_round/alg_ext.py @@ -689,7 +689,15 @@ def _run_search(self, weight, v): split_num=1, v=v_r, ) - return {"scale": scale, "wmin": wmin, "d_scale": d_scale, "d_wmin": d_wmin} + # Search funcs are decorated with ``@torch.inference_mode()``; their + # outputs are inference tensors and cannot be saved for backward. + # Clone to detach from inference mode so autograd may use them. + return { + "scale": scale.clone(), + "wmin": wmin.clone(), + "d_scale": d_scale.clone(), + "d_wmin": d_wmin.clone(), + } # sym path group_size = 16 @@ -709,7 +717,8 @@ def _run_search(self, weight, v): scale = torch.where(torch.abs(scale) < 1e-30, torch.zeros_like(scale), scale) scale, d_scale = double_quant_tensor_sym_rtn(scale, super_bits) scale = scale.unsqueeze(-1) - return {"scale": scale, "d_scale": d_scale} + # Clone to escape inference-mode tensors (see asym branch comment). + return {"scale": scale.clone(), "d_scale": d_scale.clone()} def _qdq_weight(self, value, min_scale, max_scale, scale_v=None, iter=None): """Quantizes and dequantizes weights with tuning parameters. From ae8ccd9be5f4ef023fccba120948b095d3ebe17b Mon Sep 17 00:00:00 2001 From: Wenhua Cheng Date: Tue, 12 May 2026 18:04:16 +0800 Subject: [PATCH 06/13] update --- auto_round/compressors_new/base.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/auto_round/compressors_new/base.py b/auto_round/compressors_new/base.py index 8507e107e..70186ab8f 100644 --- a/auto_round/compressors_new/base.py +++ b/auto_round/compressors_new/base.py @@ -614,12 +614,12 @@ def _adjust_torch_compile(self, enable_torch_compile: bool) -> None: self.enable_torch_compile = False logger.warning_once("reset enable_torch_compile to `False` as nvfp4 is enabled") super_group_size = getattr(cfg, "super_group_size", None) - enable_alg_ext = getattr(cfg, "enable_alg_ext", False) - if self.enable_torch_compile and super_group_size is not None and enable_alg_ext: - self.enable_torch_compile = False - logger.warning_once( - "reset enable_torch_compile to `False` as super_group_size is set for algorithm extension" - ) + # enable_alg_ext = getattr(cfg, "enable_alg_ext", False) + # if self.enable_torch_compile and super_group_size is not None and enable_alg_ext: + # self.enable_torch_compile = False + # logger.warning_once( + # "reset enable_torch_compile to `False` as super_group_size is set for algorithm extension" + # ) def _get_calibration_dataset(self) -> str: """Resolve calibration dataset: self.dataset > AutoScheme.dataset > default.""" From 1e60677cc2377d782ecd5410c743f8d32712259f Mon Sep 17 00:00:00 2001 From: Wenhua Cheng Date: Tue, 12 May 2026 21:10:06 +0800 Subject: [PATCH 07/13] update --- auto_round/compressors_new/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/auto_round/compressors_new/base.py b/auto_round/compressors_new/base.py index 70186ab8f..b9866e9d4 100644 --- a/auto_round/compressors_new/base.py +++ b/auto_round/compressors_new/base.py @@ -613,7 +613,7 @@ def _adjust_torch_compile(self, enable_torch_compile: bool) -> None: if self.enable_torch_compile and is_raw_nv_fp: self.enable_torch_compile = False logger.warning_once("reset enable_torch_compile to `False` as nvfp4 is enabled") - super_group_size = getattr(cfg, "super_group_size", None) + # super_group_size = getattr(cfg, "super_group_size", None) # enable_alg_ext = getattr(cfg, "enable_alg_ext", False) # if self.enable_torch_compile and super_group_size is not None and enable_alg_ext: # self.enable_torch_compile = False From 9a790a0677b6ac94b8cd1e4c67023fa6b5be189d Mon Sep 17 00:00:00 2001 From: Wenhua Cheng Date: Thu, 14 May 2026 17:53:40 +0800 Subject: [PATCH 08/13] fix --- auto_round/data_type/gguf.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/auto_round/data_type/gguf.py b/auto_round/data_type/gguf.py index 57c119f43..2155363e9 100644 --- a/auto_round/data_type/gguf.py +++ b/auto_round/data_type/gguf.py @@ -557,7 +557,7 @@ def quant_tensor_gguf_asym_dq( d_wmin = prev_d_wmin.detach() if scale is None: scale, wmin, d_scale, d_wmin = search_gguf_scale_min_asym( - tensor, bits, scale_dtype, imatrix, split_num=split_num + tensor, bits, scale_dtype, imatrix, split_num=split_num,v=v, ) scale = scale.clone() wmin = wmin.clone() @@ -807,7 +807,7 @@ def quant_tensor_gguf_sym_dq( scale = prev_scale.detach() d_scale = prev_d_scale.detach() if scale is None or d_scale is None: - scale = search_gguf_scale_min_sym(tensor, bits, imatrix, scale_dtype, split_num=split_num) + scale = search_gguf_scale_min_sym(tensor, bits, imatrix, scale_dtype, split_num=split_num,v=v) scale = scale.to(scale_dtype) scale = torch.where(torch.abs(scale) < 1e-30, torch.zeros_like(scale), scale) scale, d_scale = double_quant_tensor_sym_rtn(scale, super_bits) From 59d238d1730bddc11aab0d07aed522cb3f321ac8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 14 May 2026 09:56:26 +0000 Subject: [PATCH 09/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- auto_round/data_type/gguf.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/auto_round/data_type/gguf.py b/auto_round/data_type/gguf.py index 2155363e9..e6c81a1c3 100644 --- a/auto_round/data_type/gguf.py +++ b/auto_round/data_type/gguf.py @@ -557,7 +557,12 @@ def quant_tensor_gguf_asym_dq( d_wmin = prev_d_wmin.detach() if scale is None: scale, wmin, d_scale, d_wmin = search_gguf_scale_min_asym( - tensor, bits, scale_dtype, imatrix, split_num=split_num,v=v, + tensor, + bits, + scale_dtype, + imatrix, + split_num=split_num, + v=v, ) scale = scale.clone() wmin = wmin.clone() @@ -807,7 +812,7 @@ def quant_tensor_gguf_sym_dq( scale = prev_scale.detach() d_scale = prev_d_scale.detach() if scale is None or d_scale is None: - scale = search_gguf_scale_min_sym(tensor, bits, imatrix, scale_dtype, split_num=split_num,v=v) + scale = search_gguf_scale_min_sym(tensor, bits, imatrix, scale_dtype, split_num=split_num, v=v) scale = scale.to(scale_dtype) scale = torch.where(torch.abs(scale) < 1e-30, torch.zeros_like(scale), scale) scale, d_scale = double_quant_tensor_sym_rtn(scale, super_bits) From 45c14b48afe03036eb0639a29b4a7f457a44ae3a Mon Sep 17 00:00:00 2001 From: Wenhua Cheng Date: Thu, 14 May 2026 18:12:29 +0800 Subject: [PATCH 10/13] update --- .../quantization/sign_roundv2/quantizer.py | 205 +++++++++++++++--- auto_round/compressors/base.py | 14 +- 2 files changed, 187 insertions(+), 32 deletions(-) diff --git a/auto_round/algorithms/quantization/sign_roundv2/quantizer.py b/auto_round/algorithms/quantization/sign_roundv2/quantizer.py index 36e5ab9dc..e6994d5fc 100644 --- a/auto_round/algorithms/quantization/sign_roundv2/quantizer.py +++ b/auto_round/algorithms/quantization/sign_roundv2/quantizer.py @@ -22,16 +22,66 @@ from auto_round.algorithms.quantization.sign_round.config import SignRoundConfig from auto_round.algorithms.quantization.sign_round.quantizer import SignRoundQuantizer -from auto_round.data_type.gguf import quant_tensor_gguf_asym_dq, quant_tensor_gguf_sym_dq +from auto_round.data_type.gguf import ( + double_quant_tensor_sym_rtn, + quant_tensor_gguf_asym_dq, + quant_tensor_gguf_sym_dq, + search_gguf_scale_min_asym, + search_gguf_scale_min_sym, +) from auto_round.data_type.int import quant_tensor_asym, quant_tensor_sym, search_scales from auto_round.data_type.mxfp import quant_mx, search_mx_scale from auto_round.data_type.nvfp import nv_fp4, search_nvfp4_scale -from auto_round.data_type.utils import reshape_pad_tensor_by_group_size +from auto_round.data_type.utils import ( + reshape_pad_tensor_by_group_size, + revert_tensor_by_pad, + round_ste, +) from auto_round.logger import logger -from auto_round.utils import check_to_quantized, compile_func +from auto_round.utils import check_to_quantized, compile_func, get_reciprocal from auto_round.wrapper import WrapperLinear, wrapper_block +def _dq_asym_qdq(tensor, scale, wmin, bits, group_size, v=0): + """Pure asym double-quant qdq math given precomputed scale/wmin (compilable).""" + tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size) + orig_dtype = tensor.dtype + tensor = tensor.to(torch.float32) + maxq = int(2.0**bits) - 1 + inverse_scale = get_reciprocal(scale) + int_w = torch.clamp(round_ste((tensor + wmin) * inverse_scale + v), 0, maxq) + qdq = (scale * int_w - wmin).to(orig_dtype) + qdq = revert_tensor_by_pad(qdq, orig_shape=orig_shape, pad_len=pad_len) + return qdq + + +def _dq_sym_qdq(tensor, scale, bits, v=0): + """Pure sym double-quant qdq math given precomputed scale (compilable).""" + from auto_round.export.export_to_gguf.config import GGML_QUANT_SIZES, QK_K + + group_size = 16 + super_group_size = 16 + maxq = int(2.0 ** (bits - 1)) + tensor, orig_shape, pad_len = reshape_pad_tensor_by_group_size(tensor, group_size) + orig_dtype = tensor.dtype + tensor = tensor.to(torch.float32) + ggml_type = f"q{bits}_k" + block_size, _ = GGML_QUANT_SIZES[ggml_type] + n_blocks = tensor.nelement() // block_size + tensor = tensor.reshape(n_blocks, super_group_size, QK_K // super_group_size) + if isinstance(v, torch.Tensor): + v_r, _, _ = reshape_pad_tensor_by_group_size(v, group_size) + v_r = v_r.reshape(n_blocks, super_group_size, QK_K // super_group_size) + else: + v_r = v + zp = torch.full_like(scale, maxq) + inverse_scale = get_reciprocal(scale) + int_w = round_ste(tensor * inverse_scale + v_r).clip(-maxq, maxq - 1) + maxq + qdq = (scale * (int_w - zp)).to(orig_dtype) + qdq = revert_tensor_by_pad(qdq, orig_shape=orig_shape, pad_len=pad_len) + return qdq + + def _named_wrapper_block(wrapper_cls, name: str): wrapped = partial(wrapper_block, wrapper_cls=wrapper_cls) wrapped.__name__ = name @@ -92,10 +142,21 @@ def __init__(self, *args, **kwargs): def _init_tuning_params_and_quant_func(self): super()._init_tuning_params_and_quant_func() + # The double-quant search path is data-dependent and kept un-compiled, + # while ``weight_quant_func`` is the compilable pure-math half. + self._is_dq_path = False + self._dq_kind = None + self.search_func = None if hasattr(self.orig_layer, "super_group_size") and self.orig_layer.super_group_size is not None: - self.weight_quant_func = ( - quant_tensor_gguf_asym_dq if self.orig_layer.data_type == "int_asym_dq" else quant_tensor_gguf_sym_dq - ) + self._is_dq_path = True + if self.orig_layer.data_type == "int_asym_dq": + self.search_func = search_gguf_scale_min_asym + self.weight_quant_func = _dq_asym_qdq + self._dq_kind = "asym" + else: + self.search_func = search_gguf_scale_min_sym + self.weight_quant_func = _dq_sym_qdq + self._dq_kind = "sym" elif self.orig_layer.sym: self.weight_quant_func = quant_tensor_sym else: @@ -113,26 +174,120 @@ def _init_tuning_params_and_quant_func(self): if self.enable_torch_compile: self.weight_quant_func = compile_func(self.weight_quant_func, self.device) + @torch.no_grad() + def _run_search(self, weight, v): + """Per-format scale/wmin search separated from the (compilable) quant func.""" + from auto_round.export.export_to_gguf.config import GGML_QUANT_SIZES, QK_K + + bits = self.orig_layer.bits + scale_dtype = self.orig_layer.scale_dtype + imatrix = getattr(self.orig_layer, "imatrix", None) + + if self._dq_kind == "asym": + group_size = 16 if bits == 2 else 32 + t, _, _ = reshape_pad_tensor_by_group_size(weight.to(torch.float32), group_size) + v_r = v + if isinstance(v, torch.Tensor): + v_r, _, _ = reshape_pad_tensor_by_group_size(v, group_size) + scale, wmin, d_scale, d_wmin = self.search_func( + t, + bits=bits, + scale_dtype=scale_dtype, + imatrix=imatrix, + split_num=1, + v=v_r, + ) + # Search funcs use ``@torch.inference_mode()``; clone to detach so + # autograd may consume them. + return { + "scale": scale.clone(), + "wmin": wmin.clone(), + "d_scale": d_scale.clone(), + "d_wmin": d_wmin.clone(), + } + + # sym path + group_size = 16 + super_group_size = 16 + t, _, _ = reshape_pad_tensor_by_group_size(weight.to(torch.float32), group_size) + ggml_type = f"q{bits}_k" + block_size, _ = GGML_QUANT_SIZES[ggml_type] + n_blocks = t.nelement() // block_size + t = t.reshape(n_blocks, super_group_size, QK_K // super_group_size) + v_r = v + if isinstance(v, torch.Tensor): + v_r, _, _ = reshape_pad_tensor_by_group_size(v, group_size) + v_r = v_r.reshape(n_blocks, super_group_size, QK_K // super_group_size) + super_bits = 6 if bits == 3 else 8 + scale = self.search_func(t, bits, imatrix, scale_dtype, split_num=1, v=v_r) + scale = scale.to(scale_dtype) + scale = torch.where(torch.abs(scale) < 1e-30, torch.zeros_like(scale), scale) + scale, d_scale = double_quant_tensor_sym_rtn(scale, super_bits) + scale = scale.unsqueeze(-1) + return {"scale": scale.clone(), "d_scale": d_scale.clone()} + def _qdq_weight(self, value, min_scale, max_scale): - weight_q, scale, zp = super()._qdq_weight(value, min_scale, max_scale) - if isinstance(scale, dict) and "d_scale" in scale and self.prev_scale is None: - self.prev_scale = scale["scale"] - self.prev_d_scale = scale["d_scale"] - if isinstance(zp, dict): - self.prev_wmin = zp["wmin"] - self.prev_d_wmin = zp["d_wmin"] - elif self.prev_scale is None: - self.prev_scale = scale - return weight_q, scale, zp - - def _extra_quant_kwargs(self): - return { - "prev_scale": self.prev_scale, - "prev_wmin": self.prev_wmin, - "prev_d_scale": self.prev_d_scale, - "prev_d_wmin": self.prev_d_wmin, - "iter": getattr(self, "cur_iter", None), - } + if not self._is_dq_path: + # Non-dq path keeps the original behavior (base class handles it). + return super()._qdq_weight(value, min_scale, max_scale) + + if self.orig_layer.bits >= 16: + return self.orig_layer.weight, None, None + min_bound, max_bound = self.minmax_scale_bound + min_scale.data.clamp_(min_bound, max_bound) + max_scale.data.clamp_(min_bound, max_bound) + weight = self.orig_layer.weight + if weight.device.type == "meta": + weight = self.orig_layer.get_weight().to(self.device) + if isinstance(self.orig_layer, transformers.pytorch_utils.Conv1D): + weight = weight.t() + + # Re-search every 10 steps; otherwise reuse the cached search results. + iter_v = getattr(self, "cur_iter", 0) + need_search = (iter_v % 10 == 0) or (iter_v == -1) or (self.prev_scale is None) + if need_search: + params = self._run_search(weight, value) + self.prev_scale = params["scale"] + self.prev_d_scale = params["d_scale"] + if self._dq_kind == "asym": + self.prev_wmin = params["wmin"] + self.prev_d_wmin = params["d_wmin"] + else: + params = { + "scale": self.prev_scale.detach(), + "d_scale": self.prev_d_scale.detach(), + } + if self._dq_kind == "asym": + params["wmin"] = self.prev_wmin.detach() + params["d_wmin"] = self.prev_d_wmin.detach() + + bits = self.orig_layer.bits + if self._dq_kind == "asym": + group_size = 16 if bits == 2 else 32 + weight_q = self.weight_quant_func( + weight, + params["scale"], + params["wmin"], + bits, + group_size, + v=value, + ) + scale_out = {"scale": params["scale"], "d_scale": params["d_scale"]} + zp_out = {"wmin": params["wmin"], "d_wmin": params["d_wmin"]} + else: + weight_q = self.weight_quant_func( + weight, + params["scale"], + bits, + v=value, + ) + scale_out = {"scale": params["scale"], "d_scale": params["d_scale"]} + zp_out = torch.full_like(params["scale"], int(2.0 ** (bits - 1))) + + weight_q = weight_q.to(weight.dtype) + if isinstance(self.orig_layer, transformers.pytorch_utils.Conv1D): + weight_q = weight_q.t() + return weight_q, scale_out, zp_out class SignRoundV2Quantizer(SignRoundQuantizer): diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index d70010dc1..44027ce2f 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -644,13 +644,13 @@ def _adjust_torch_compile(self, enable_torch_compile: bool) -> None: if self.enable_torch_compile and is_raw_nv_fp: self.enable_torch_compile = False logger.warning_once("reset enable_torch_compile to `False` as nvfp4 is enabled") - super_group_size = getattr(cfg, "super_group_size", None) - enable_alg_ext = getattr(cfg, "enable_alg_ext", False) - if self.enable_torch_compile and super_group_size is not None and enable_alg_ext: - self.enable_torch_compile = False - logger.warning_once( - "reset enable_torch_compile to `False` as super_group_size is set for algorithm extension" - ) + #super_group_size = getattr(cfg, "super_group_size", None) + # enable_alg_ext = getattr(cfg, "enable_alg_ext", False) + # if self.enable_torch_compile and super_group_size is not None and enable_alg_ext: + # self.enable_torch_compile = False + # logger.warning_once( + # "reset enable_torch_compile to `False` as super_group_size is set for algorithm extension" + # ) def _get_calibration_dataset(self) -> str: """Resolve calibration dataset: self.dataset > AutoScheme.dataset > default.""" From e3d0eb9a89b63c6b5b17a69a26331b9aeb9d81fd Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 14 May 2026 10:13:17 +0000 Subject: [PATCH 11/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- auto_round/compressors/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index 44027ce2f..25a4308f7 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -644,7 +644,7 @@ def _adjust_torch_compile(self, enable_torch_compile: bool) -> None: if self.enable_torch_compile and is_raw_nv_fp: self.enable_torch_compile = False logger.warning_once("reset enable_torch_compile to `False` as nvfp4 is enabled") - #super_group_size = getattr(cfg, "super_group_size", None) + # super_group_size = getattr(cfg, "super_group_size", None) # enable_alg_ext = getattr(cfg, "enable_alg_ext", False) # if self.enable_torch_compile and super_group_size is not None and enable_alg_ext: # self.enable_torch_compile = False From 52203101c73b13d841010ea70faf177e6ee0e18b Mon Sep 17 00:00:00 2001 From: Wenhua Cheng Date: Fri, 15 May 2026 13:50:19 +0800 Subject: [PATCH 12/13] fix ut --- .../quantization/sign_roundv2/quantizer.py | 12 +++++++++--- auto_round/data_type/gguf.py | 3 ++- auto_round/export/export_to_gguf/packing.py | 3 ++- 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/auto_round/algorithms/quantization/sign_roundv2/quantizer.py b/auto_round/algorithms/quantization/sign_roundv2/quantizer.py index e6994d5fc..75a730d58 100644 --- a/auto_round/algorithms/quantization/sign_roundv2/quantizer.py +++ b/auto_round/algorithms/quantization/sign_roundv2/quantizer.py @@ -134,6 +134,9 @@ class SignRoundDQWrapperLinear(WrapperLinear): minmax_scale_bound = (0.5, 1.5) def __init__(self, *args, **kwargs): + if "enable_minmax_tuning" in kwargs: + logger.warning_once("diable minmax tuning for a little better accuracy and lower cost") + kwargs["enable_minmax_tuning"] = False # a little faster and better super().__init__(*args, **kwargs) self.prev_scale = None self.prev_wmin = None @@ -234,7 +237,7 @@ def _qdq_weight(self, value, min_scale, max_scale): if self.orig_layer.bits >= 16: return self.orig_layer.weight, None, None min_bound, max_bound = self.minmax_scale_bound - min_scale.data.clamp_(min_bound, max_bound) + min_scale.data.clamp_(min_bound, max_bound) # TODO this one could be deleted max_scale.data.clamp_(min_bound, max_bound) weight = self.orig_layer.weight if weight.device.type == "meta": @@ -244,7 +247,7 @@ def _qdq_weight(self, value, min_scale, max_scale): # Re-search every 10 steps; otherwise reuse the cached search results. iter_v = getattr(self, "cur_iter", 0) - need_search = (iter_v % 10 == 0) or (iter_v == -1) or (self.prev_scale is None) + need_search = (iter_v==0) or (iter_v == -1) or (self.prev_scale is None) if need_search: params = self._run_search(weight, value) self.prev_scale = params["scale"] @@ -308,7 +311,10 @@ def __init__(self, config: SignRoundConfig): "algorithm extension has only undergone limited validation on " "W2A16,INT4, MXFP4 and NVFP4; use with caution." ) - self._use_outlier_suppressed_loss = True + if self.act_bits<=4 or self.bits<4: + self._use_outlier_suppressed_loss = True + else: + self._use_outlier_suppressed_loss = False self.wrapper_block = _named_wrapper_block(SignRoundOptimizedWrapperLinear, "wrapper_block") if self.data_type.endswith("dq"): diff --git a/auto_round/data_type/gguf.py b/auto_round/data_type/gguf.py index e6c81a1c3..7fea96c9a 100644 --- a/auto_round/data_type/gguf.py +++ b/auto_round/data_type/gguf.py @@ -587,7 +587,8 @@ def iterative_wls_quant_search_chunk( maxq = int(2.0**bits) - 1 minq = 0 weights = 1.0 if weights is None else weights.to(dtype) - v_is_tensor = isinstance(v, torch.Tensor) + # A 0-dim tensor (scalar tensor) cannot be sliced; treat it like a Python scalar. + v_is_tensor = isinstance(v, torch.Tensor) and v.dim() > 0 results_scale = [] results_rmin = [] diff --git a/auto_round/export/export_to_gguf/packing.py b/auto_round/export/export_to_gguf/packing.py index 82d5ac50c..b1b60dc6b 100644 --- a/auto_round/export/export_to_gguf/packing.py +++ b/auto_round/export/export_to_gguf/packing.py @@ -147,7 +147,8 @@ def make_qx_quants_chunk(data, bits, rmse_type=0, qw=None, split_num=1, v=0): scales_list = [] L_list = [] chunk_size = (data.shape[0] + split_num - 1) // split_num - v_is_tensor = isinstance(v, torch.Tensor) + # A 0-dim tensor (scalar tensor) cannot be sliced; treat it like a Python scalar. + v_is_tensor = isinstance(v, torch.Tensor) and v.dim() > 0 for start in range(0, data.shape[0], chunk_size): end = min(start + chunk_size, data.shape[0]) chunk = data[start:end] # Slice a batch chunk to reduce memory footprint From dcc8dbab957001fb6f8c3266c3b4b35b6bcb1231 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 15 May 2026 05:54:39 +0000 Subject: [PATCH 13/13] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../algorithms/quantization/sign_roundv2/quantizer.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/auto_round/algorithms/quantization/sign_roundv2/quantizer.py b/auto_round/algorithms/quantization/sign_roundv2/quantizer.py index 75a730d58..145215dd1 100644 --- a/auto_round/algorithms/quantization/sign_roundv2/quantizer.py +++ b/auto_round/algorithms/quantization/sign_roundv2/quantizer.py @@ -135,8 +135,8 @@ class SignRoundDQWrapperLinear(WrapperLinear): def __init__(self, *args, **kwargs): if "enable_minmax_tuning" in kwargs: - logger.warning_once("diable minmax tuning for a little better accuracy and lower cost") - kwargs["enable_minmax_tuning"] = False # a little faster and better + logger.warning_once("disable minmax tuning for a little better accuracy and lower cost") + kwargs["enable_minmax_tuning"] = False # a little faster and better super().__init__(*args, **kwargs) self.prev_scale = None self.prev_wmin = None @@ -237,7 +237,7 @@ def _qdq_weight(self, value, min_scale, max_scale): if self.orig_layer.bits >= 16: return self.orig_layer.weight, None, None min_bound, max_bound = self.minmax_scale_bound - min_scale.data.clamp_(min_bound, max_bound) # TODO this one could be deleted + min_scale.data.clamp_(min_bound, max_bound) # TODO this one could be deleted max_scale.data.clamp_(min_bound, max_bound) weight = self.orig_layer.weight if weight.device.type == "meta": @@ -247,7 +247,7 @@ def _qdq_weight(self, value, min_scale, max_scale): # Re-search every 10 steps; otherwise reuse the cached search results. iter_v = getattr(self, "cur_iter", 0) - need_search = (iter_v==0) or (iter_v == -1) or (self.prev_scale is None) + need_search = (iter_v == 0) or (iter_v == -1) or (self.prev_scale is None) if need_search: params = self._run_search(weight, value) self.prev_scale = params["scale"] @@ -311,7 +311,7 @@ def __init__(self, config: SignRoundConfig): "algorithm extension has only undergone limited validation on " "W2A16,INT4, MXFP4 and NVFP4; use with caution." ) - if self.act_bits<=4 or self.bits<4: + if self.act_bits <= 4 or self.bits < 4: self._use_outlier_suppressed_loss = True else: self._use_outlier_suppressed_loss = False