diff --git a/reference_hierarchical_nvfp4/__init__.py b/reference_hierarchical_nvfp4/__init__.py new file mode 100644 index 0000000000..a0b3192cb8 --- /dev/null +++ b/reference_hierarchical_nvfp4/__init__.py @@ -0,0 +1,36 @@ +"""NVFP4 reference: same *recipe* as TE (S_enc, S_dec fp8, bsi, FP4) except +``S_enc = (fp8_max*fp4_max)/amax`` uses each **1x64** window's max instead of +per-tensor global amax. See ``fp8_e4m3_utils_np`` and ``core_nvfp4.cuh``. + +- PyTorch: symbols below (requires ``torch`` + ``numpy`` for E4M3). +- CPU: ``hierarchical_nvfp4_ref_numpy`` (``numpy`` only), or run + ``python reference_hierarchical_nvfp4/hierarchical_nvfp4_ref_numpy.py``. +""" + +from .hierarchical_nvfp4_ref import ( + COARSE, + FINE, + HierarchicalNVFP4Colwise, + HierarchicalNVFP4Rowwise, + dequantize_colwise, + dequantize_rowwise, + fp4_e2m1_grid_torch, + quantize_columnwise_1x64_1x16, + quantize_rowwise_1x64_1x16, + reference_matmul_tn, + roundtrip_error, +) + +__all__ = [ + "COARSE", + "FINE", + "HierarchicalNVFP4Colwise", + "HierarchicalNVFP4Rowwise", + "dequantize_colwise", + "dequantize_rowwise", + "fp4_e2m1_grid_torch", + "quantize_columnwise_1x64_1x16", + "quantize_rowwise_1x64_1x16", + "reference_matmul_tn", + "roundtrip_error", +] diff --git a/reference_hierarchical_nvfp4/compare_64x64_global_vs_1x64.py b/reference_hierarchical_nvfp4/compare_64x64_global_vs_1x64.py new file mode 100644 index 0000000000..cfdad13cfa --- /dev/null +++ b/reference_hierarchical_nvfp4/compare_64x64_global_vs_1x64.py @@ -0,0 +1,116 @@ +# Compare: TE-style NVFP4 (single S_enc from *global* amax) vs +# reference 1x64 (S_enc from max|x| in each 1x64 K-window = per-row for K=64). +# Matrix shape (M, K) = (64, 64), rowwise along K, 1x16 blocks * 4 per row. +# +# Run (no torch): python3 reference_hierarchical_nvfp4/compare_64x64_global_vs_1x64.py + +from __future__ import annotations + +import os +import sys + +import numpy as np + +# Load sibling ref modules (avoid package __init__ -> torch) +_REF_DIR = os.path.dirname(os.path.abspath(__file__)) +if _REF_DIR not in sys.path: + sys.path.insert(0, _REF_DIR) + +from fp8_e4m3_utils_np import ( + TINY, + compute_S_dec_f32_before_cast_te, + compute_S_enc_from_amax_1x64_like_te, + e4m3_u8_to_f32, + f32_to_e4m3_u8, +) +from hierarchical_nvfp4_ref_numpy import ( + FINE, + FP4_E2M1_GRID, + dequantize_rowwise, + quantize_rowwise_1x64_1x16, +) + +M = 64 +K = 64 + + +def _round_nearest_fp4(x: np.ndarray) -> np.ndarray: + d = np.abs(x[..., None] - FP4_E2M1_GRID) + return FP4_E2M1_GRID[np.argmin(d, axis=-1).astype(np.int64)] + + +def te_nvfp4_rowwise_global_senc( + x: np.ndarray, eps: float = 1e-12 +) -> tuple[np.ndarray, float, np.ndarray, np.ndarray]: + """ + Same math as ref but S_enc = 2688 / global_amax (single float for all blocks). 
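+    (Here 2688 = FP8_E4M3_FMAX * FP4_E2M1_FMAX = 448 * 6; see S_ENC_NUMER in
+    fp8_e4m3_utils_np.)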
+
+    Returns: x_recon, global_amax, w_pre_fp4, S_dec_u8 (M, n16) per-block fp8
+    """
+    x = np.asarray(x, np.float32)
+    m, k = x.shape
+    g_amax = float(np.max(np.abs(x)))
+    S_g = compute_S_enc_from_amax_1x64_like_te(g_amax)
+    n16 = (k + FINE - 1) // FINE
+    w = np.empty_like(x)
+    s_dec = np.empty((m, n16), dtype=np.uint8)
+    for r in range(m):
+        t16b = 0
+        while t16b * FINE < k:
+            lo, hi = t16b * FINE, min((t16b + 1) * FINE, k)
+            segx = x[r, lo:hi]
+            bamax = float(np.max(np.abs(segx)))
+            if bamax < eps:
+                bamax = float(eps)
+            raw = compute_S_dec_f32_before_cast_te(bamax, S_g)
+            u = f32_to_e4m3_u8(np.array([raw], dtype=np.float32).reshape(1))
+            s_dec[r, t16b] = u.reshape(-1)[0]
+            s_d = max(float(e4m3_u8_to_f32(s_dec[r : r + 1, t16b : t16b + 1]).reshape(-1)[0]), TINY)
+            bsi = S_g / s_d
+            w[r, lo:hi] = segx * bsi
+            t16b += 1
+    q = _round_nearest_fp4(w)
+    t16g = (np.arange(k) // FINE).astype(np.int64)
+    sde = e4m3_u8_to_f32(s_dec[:, t16g].astype(np.uint8))
+    sde = np.maximum(sde, TINY)
+    x_recon = q * (sde / S_g)
+    return x_recon.astype(np.float32), g_amax, w, s_dec
+
+
+def main() -> None:
+    rng = np.random.default_rng(2026)
+    # "Real" data: not uniform — heavy-tailed + one row scaled up for local/global gap
+    x = rng.standard_normal((M, K)).astype(np.float32)
+    x *= 0.35
+    x[7, :] *= 4.0
+    x[0:8, 12:20] += 0.4
+
+    x_ref1 = quantize_rowwise_1x64_1x16(x)
+    recon_1x64 = dequantize_rowwise(x_ref1)
+    recon_global, g_amax, _, _ = te_nvfp4_rowwise_global_senc(x)
+
+    d = np.abs(recon_1x64 - x)
+    d2 = np.abs(recon_global - x)
+    dg = np.abs(recon_1x64 - recon_global)
+
+    print("=== 64x64 numerical comparison (rowwise, K = 4x16) ===")
+    print("global_amax =", g_amax)
+    print("S_enc production (global) = 2688 / global_amax =", compute_S_enc_from_amax_1x64_like_te(g_amax))
+    print("---")
+    print("quantize->dequantize vs original tensor: max abs err [1x64 S_enc reference]:", float(np.max(d)))
+    print("quantize->dequantize vs original tensor: max abs err [global S_enc]        :", float(np.max(d2)))
+    print("RMS error vs original tensor [1x64]  :", float(np.sqrt(np.mean(d**2))))
+    print("RMS error vs original tensor [global]:", float(np.sqrt(np.mean(d2**2))))
+    print("---")
+    print("max abs diff between the two reconstructions |recon_1x64 - recon_global| :", float(np.max(dg)))
+    print("RMS( recon_1x64 - recon_global ) :", float(np.sqrt(np.mean(dg**2))))
+    fn = float(np.linalg.norm(x, "fro"))
+    if fn > 0:
+        print("||x||_F =", fn)
+        print(
+            "Frobenius relative: ||recon_1x64 - recon_global||_F / ||x||_F =",
+            float(np.linalg.norm(dg, "fro") / fn),
+        )
+
+
+if __name__ == "__main__":
+    main()
diff --git a/reference_hierarchical_nvfp4/fp8_e4m3_utils_np.py b/reference_hierarchical_nvfp4/fp8_e4m3_utils_np.py
new file mode 100644
index 0000000000..cecd42a537
--- /dev/null
+++ b/reference_hierarchical_nvfp4/fp8_e4m3_utils_np.py
@@ -0,0 +1,71 @@
+# FP8 E4M3: decode uint8->float32 and encode f32->nearest uint8 (all 256 codes).
+# Matches OCP/FP8 E4M3 for reference roundtrip; close to CUDA static_cast(f32).
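+#
+# Spot checks of the decode formula implemented below (added for illustration;
+# each value follows directly from (1 + man/8) * 2^(exp - 7), or
+# (man/8) * 2^-6 when exp == 0):
+#   0x7E -> (1 + 6/8) * 2^(15 - 7) = 448.0   (largest finite E4M3 value)
+#   0x4C -> (1 + 4/8) * 2^(9 - 7)  = 6.0     (equals FP4_E2M1_FMAX)
+#   0x01 -> (1/8)     * 2^-6       = 2^-9 ~= 0.00195 (smallest positive subnormal)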
+ +from __future__ import annotations + +import numpy as np + +# Match transformer_engine::detail::TypeExtrema (common.h) / nvfp4.cu kernels +FP8_E4M3_FMAX: float = 448.0 +FP4_E2M1_FMAX: float = 6.0 +S_ENC_NUMER: float = FP8_E4M3_FMAX * FP4_E2M1_FMAX # 2688 +FLT_MAX: float = 3.402823466e38 +TINY: float = 1.17549435e-38 + + +def _decode_e4m3_byte(b: int) -> float: + u = b & 0xFF + sign = u >> 7 + exp = (u >> 3) & 0x0F + man = u & 0x07 + if exp == 0: + if man == 0: + return 0.0 + v = (man / 8.0) * (2.0 ** (-6)) + else: + v = (1.0 + man / 8.0) * (2.0 ** (exp - 7)) + return -v if sign else v + + +# Precompute 256 float32 values for all E4M3 codes +_E4M3_TABLE: np.ndarray = np.array([_decode_e4m3_byte(i) for i in range(256)], dtype=np.float32) + + +def f32_to_e4m3_u8(x: np.ndarray) -> np.ndarray: + """Round each element to nearest fp8e4m3 (by L_inf on 256 codes). x can be any shape.""" + x = np.asarray(x, dtype=np.float32) + flat = x.ravel()[:, None] + d = np.abs(flat - _E4M3_TABLE[None, :]) + out = np.argmin(d, axis=1).astype(np.uint8) + return out.reshape(x.shape) + + +def e4m3_u8_to_f32(b: np.ndarray) -> np.ndarray: + return _E4M3_TABLE[np.asarray(b, dtype=np.int32) & 0xFF].astype(np.float32) + + +def compute_S_enc_from_amax_1x64_like_te(amax: float, eps: float = TINY) -> float: + """ + Same as compute_global_encode_scaling_factor_FP4 in core_nvfp4.cuh, but *amax* is + the 1x64 local max (replaces per-tensor global amax in this ref). + """ + a = float(amax) + if a <= 0.0 or not np.isfinite(a): + return 1.0 + safe = max(a, eps) + g = S_ENC_NUMER / safe + return float(min(g, FLT_MAX)) + + +def compute_S_dec_f32_before_cast_te(block_amax: float, S_enc: float) -> float: + """ + Unquantized S_dec = block_amax * (S_enc / 6) before cast to e4m3 + (compute_decoding_scaling_factor, quantization_SF in core_nvfp4.cuh). + """ + return float(np.float32(block_amax * (S_enc * (1.0 / FP4_E2M1_FMAX)))) + + +def f32_e4m3_f32(x: float) -> float: + """Cast pipeline f32 -> fp8e4m3 -> f32, numpy.""" + u = f32_to_e4m3_u8(np.array([x], dtype=np.float32)) + return float(e4m3_u8_to_f32(u)[0]) diff --git a/reference_hierarchical_nvfp4/hierarchical_nvfp4_ref.py b/reference_hierarchical_nvfp4/hierarchical_nvfp4_ref.py new file mode 100644 index 0000000000..8df743d781 --- /dev/null +++ b/reference_hierarchical_nvfp4/hierarchical_nvfp4_ref.py @@ -0,0 +1,293 @@ +# PyTorch twin of `hierarchical_nvfp4_ref_numpy.py`. +# +# Aligned with TE NVFP4 (core_nvfp4.cuh, quantize_nvfp4.cuh) for: +# S_enc, S_dec = cast_f32(block_amax * S_enc / 6) as fp8e4m3, block_scale_inv = S_enc / f32(S_dec), +# then FP4 from x * bsi. **S_enc** uses **1x64 local amax** instead of per-tensor global amax. +# +# Standalone. 
CPU twin: hierarchical_nvfp4_ref_numpy.py + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Tuple + +import torch + +import numpy as np + +from .fp8_e4m3_utils_np import TINY, compute_S_dec_f32_before_cast_te +from .fp8_e4m3_utils_np import compute_S_enc_from_amax_1x64_like_te as _s_enc_from_amax_cpu + +try: + from .fp8_e4m3_utils_np import e4m3_u8_to_f32 as _e4m3_u8_f32_np + from .fp8_e4m3_utils_np import f32_to_e4m3_u8 as _f32_e4m3_u8_np +except ImportError: + import os + import sys + + sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + from fp8_e4m3_utils_np import e4m3_u8_to_f32 as _e4m3_u8_f32_np + from fp8_e4m3_utils_np import f32_to_e4m3_u8 as _f32_e4m3_u8_np + +COARSE = 64 +FINE = 16 + + +def fp4_e2m1_grid_torch(device, dtype) -> torch.Tensor: + vals = ( + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, + ) + return torch.tensor(vals, device=device, dtype=dtype) + + +def _round_to_nearest_fp4(x: torch.Tensor, grid: torch.Tensor) -> torch.Tensor: + diff = (x.unsqueeze(-1) - grid).abs() + return grid[diff.argmin(dim=-1)] + + +def _pack_fp4_along_k(x_fp4: torch.Tensor) -> torch.Tensor: + m, k = x_fp4.shape + grid = fp4_e2m1_grid_torch(x_fp4.device, torch.float32) + diff = (x_fp4.unsqueeze(-1) - grid).abs() + nibble = diff.argmin(dim=-1).to(torch.uint8) + n_pairs = (k + 1) // 2 + out = torch.zeros(m, n_pairs, dtype=torch.uint8, device=x_fp4.device) + for p in range(k // 2): + j = 2 * p + out[:, p] = (nibble[:, j] & 0x0F) | ((nibble[:, j + 1] & 0x0F) << 4) + if k % 2 == 1: + out[:, -1] = nibble[:, -1] & 0x0F + return out + + +def _unpack_fp4_along_k(data_u8: torch.Tensor, k: int) -> torch.Tensor: + m = data_u8.size(0) + grid = fp4_e2m1_grid_torch(data_u8.device, torch.float32) + out = torch.empty(m, k, device=data_u8.device, dtype=grid.dtype) + p = 0 + for j in range(0, k - 1, 2): + b = data_u8[:, p] + p += 1 + out[:, j] = grid[(b & 0x0F).to(torch.long)] + out[:, j + 1] = grid[((b >> 4) & 0x0F).to(torch.long)] + if k % 2 == 1: + b = data_u8[:, p] + out[:, -1] = grid[(b & 0x0F).to(torch.long)] + return out + + +def _pack_fp4_along_m(x_fp4: torch.Tensor) -> torch.Tensor: + m, k = x_fp4.shape + grid = fp4_e2m1_grid_torch(x_fp4.device, torch.float32) + diff = (x_fp4.unsqueeze(-1) - grid).abs() + nibble = diff.argmin(dim=-1).to(torch.uint8) + n_pairs = (m + 1) // 2 + out = torch.zeros(n_pairs, k, dtype=torch.uint8, device=x_fp4.device) + for p in range(m // 2): + r = 2 * p + out[p, :] = (nibble[r, :] & 0x0F) | ((nibble[r + 1, :] & 0x0F) << 4) + if m % 2 == 1: + out[-1, :] = nibble[-1, :] & 0x0F + return out + + +def _unpack_fp4_along_m(data_u8: torch.Tensor, m: int, k: int) -> torch.Tensor: + grid = fp4_e2m1_grid_torch(data_u8.device, torch.float32) + out = torch.empty(m, k, device=data_u8.device, dtype=grid.dtype) + p = 0 + for r in range(0, m - 1, 2): + b = data_u8[p, :] + p += 1 + out[r, :] = grid[(b & 0x0F).to(torch.long)] + out[r + 1, :] = grid[((b >> 4) & 0x0F).to(torch.long)] + if m % 2 == 1: + b = data_u8[p, :] + out[-1, :] = grid[(b & 0x0F).to(torch.long)] + return out + + +@dataclass +class HierarchicalNVFP4Rowwise: + m: int + k: int + data_u8: torch.Tensor + S_enc: torch.Tensor + S_dec_u8: torch.Tensor + amax_64: torch.Tensor + + +@dataclass +class HierarchicalNVFP4Colwise: + m: int + k: int + data_u8: torch.Tensor + S_enc: torch.Tensor + S_dec_u8: torch.Tensor + amax_64: torch.Tensor + + +def _amax_64_k(x: torch.Tensor) -> torch.Tensor: + m, k = 
int(x.size(0)), int(x.size(1)) + n64 = (k + COARSE - 1) // COARSE + a = x.new_empty(m, n64) + for t64 in range(n64): + lo, hi = t64 * COARSE, min((t64 + 1) * COARSE, k) + a[:, t64] = x[:, lo:hi].abs().max(dim=1).values + return a + + +def quantize_rowwise_1x64_1x16(x: torch.Tensor, eps: float = 1e-12) -> HierarchicalNVFP4Rowwise: + assert x.dim() == 2 + m, k = int(x.size(0)), int(x.size(1)) + device, dtype = x.device, x.dtype + x = x.to(torch.float32) + n16 = (k + FINE - 1) // FINE + amax_64 = _amax_64_k(x) + n64 = amax_64.size(1) + S_enc = torch.empty(m, n64, device=device, dtype=torch.float32) + for ri in range(m): + for t64 in range(n64): + S_enc[ri, t64] = _s_enc_from_amax_cpu(float(amax_64[ri, t64].item())) + S_dec_u8 = torch.empty(m, n16, device=device, dtype=torch.uint8) + w = x.clone() + grid = fp4_e2m1_grid_torch(device, torch.float32) + for row in range(m): + t16b = 0 + while t16b * FINE < k: + lo, hi = t16b * FINE, min((t16b + 1) * FINE, k) + t64 = lo // COARSE + s_e = S_enc[row, t64] + segx = x[row, lo:hi] + bamax = float(segx.abs().max().item()) + if bamax < eps: + bamax = float(eps) + raw = compute_S_dec_f32_before_cast_te(bamax, float(s_e.item())) + u = int(_f32_e4m3_u8_np(np.array([raw], dtype=np.float32).reshape(1))[0]) + S_dec_u8[row, t16b] = u + s_dec_f = max(float(_e4m3_u8_f32_np(np.array([u], dtype=np.uint8))[0]), TINY) + bsi = float(s_e.item()) / s_dec_f + w[row, lo:hi] = segx * bsi + t16b += 1 + q = _round_to_nearest_fp4(w, grid) + return HierarchicalNVFP4Rowwise(m, k, _pack_fp4_along_k(q), S_enc, S_dec_u8, amax_64) + + +def dequantize_rowwise(p: HierarchicalNVFP4Rowwise) -> torch.Tensor: + m, k, device = p.m, p.k, p.data_u8.device + q = _unpack_fp4_along_k(p.data_u8, k) + j16 = (torch.arange(k, device=device) // FINE).long() + j64 = (torch.arange(k, device=device) // COARSE).long() + sdec = torch.from_numpy(_e4m3_u8_f32_np(p.S_dec_u8[:, j16].cpu().numpy().astype(np.uint8))).to( + device=device, dtype=torch.float32 + ) + senc = p.S_enc[:, j64] + sdec = torch.clamp(sdec, min=TINY) + return (q * (sdec / senc)).to(torch.float32) + + +def _amax_64_m(x: torch.Tensor) -> torch.Tensor: + m, k = int(x.size(0)), int(x.size(1)) + n64 = (m + COARSE - 1) // COARSE + a = x.new_empty(n64, k) + for t64 in range(n64): + lo, hi = t64 * COARSE, min((t64 + 1) * COARSE, m) + a[t64, :] = x[lo:hi, :].abs().max(dim=0).values + return a + + +def quantize_columnwise_1x64_1x16(x: torch.Tensor, eps: float = 1e-12) -> HierarchicalNVFP4Colwise: + assert x.dim() == 2 + m, k = int(x.size(0)), int(x.size(1)) + device = x.device + x = x.to(torch.float32) + n16 = (m + FINE - 1) // FINE + amax_64 = _amax_64_m(x) + n64 = amax_64.size(0) + S_enc = torch.empty(n64, k, device=device, dtype=torch.float32) + for t64 in range(n64): + for col in range(k): + S_enc[t64, col] = _s_enc_from_amax_cpu(float(amax_64[t64, col].item())) + S_dec_u8 = torch.empty(n16, k, device=device, dtype=torch.uint8) + w = x.clone() + grid = fp4_e2m1_grid_torch(device, torch.float32) + for col in range(k): + t16b = 0 + while t16b * FINE < m: + lo, hi = t16b * FINE, min((t16b + 1) * FINE, m) + t64 = lo // COARSE + s_e = S_enc[t64, col] + segx = x[lo:hi, col] + bamax = float(segx.abs().max().item()) + if bamax < eps: + bamax = float(eps) + raw = compute_S_dec_f32_before_cast_te(bamax, float(s_e.item())) + u = int(_f32_e4m3_u8_np(np.array([raw], dtype=np.float32).reshape(1))[0]) + S_dec_u8[t16b, col] = u + s_dec_f = max(float(_e4m3_u8_f32_np(np.array([u], dtype=np.uint8))[0]), TINY) + bsi = float(s_e.item()) / s_dec_f + w[lo:hi, col] = 
segx * bsi + t16b += 1 + q = _round_to_nearest_fp4(w, grid) + return HierarchicalNVFP4Colwise(m, k, _pack_fp4_along_m(q), S_enc, S_dec_u8, amax_64) + + +def dequantize_colwise(p: HierarchicalNVFP4Colwise) -> torch.Tensor: + m, k, device = p.m, p.k, p.data_u8.device + q = _unpack_fp4_along_m(p.data_u8, m, k) + r16 = (torch.arange(m, device=device) // FINE).long() + r64 = (torch.arange(m, device=device) // COARSE).long() + sdec = torch.from_numpy(_e4m3_u8_f32_np(p.S_dec_u8[r16, :].cpu().numpy().astype(np.uint8))).to( + device=device, dtype=torch.float32 + ) + senc = p.S_enc[r64, :] + sdec = torch.clamp(sdec, min=TINY) + return (q * (sdec / senc)).to(torch.float32) + + +def reference_matmul_tn( + a_rows: HierarchicalNVFP4Rowwise, + b_cols: HierarchicalNVFP4Colwise, +) -> torch.Tensor: + return dequantize_rowwise(a_rows) @ dequantize_colwise(b_cols).T + + +def roundtrip_error(x: torch.Tensor, mode: str) -> Tuple[torch.Tensor, torch.Tensor]: + if mode == "rowwise": + p = quantize_rowwise_1x64_1x16(x) + y = dequantize_rowwise(p) + elif mode == "colwise": + p = quantize_columnwise_1x64_1x16(x) + y = dequantize_colwise(p) + else: + raise ValueError("mode is rowwise or colwise") + e = (x.to(torch.float32) - y).abs().max() + return y, e + + +if __name__ == "__main__": + torch.manual_seed(0) + m, n, kdim = 4, 5, 128 + a = torch.randn(m, kdim) + b = torch.randn(n, kdim) + pa, pb = quantize_rowwise_1x64_1x16(a), quantize_columnwise_1x64_1x16(b) + y = reference_matmul_tn(pa, pb) + y_ref = a @ b.T + err = (y - y_ref).abs().max() + print("matmul ref max abs err (via dequant):", float(err)) + _, e = roundtrip_error(a, "rowwise") + print("rowwise roundtrip max abs err:", float(e)) diff --git a/reference_hierarchical_nvfp4/hierarchical_nvfp4_ref_numpy.py b/reference_hierarchical_nvfp4/hierarchical_nvfp4_ref_numpy.py new file mode 100644 index 0000000000..970ac41917 --- /dev/null +++ b/reference_hierarchical_nvfp4/hierarchical_nvfp4_ref_numpy.py @@ -0,0 +1,295 @@ +# CPU reference aligned with TE NVFP4 (core_nvfp4.cuh + quantize_nvfp4.cuh path) except: +# S_enc = (fp8_max * fp4_max) / amax uses **1x64 local amax** instead of per-tensor global amax. +# +# Stages: S_enc(1x64) -> S_dec = cast_fp8( block_amax_1x16 * S_enc / 6 ) -> bsi = S_enc / f32(S_dec) +# -> t = x * bsi -> FP4 E2M1 (nearest). Dequant: x_hat = q_fp4 * (f32(S_dec) / S_enc). +# +# Requires: numpy. 
Run: python TransformerEngine/reference_hierarchical_nvfp4/hierarchical_nvfp4_ref_numpy.py + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Tuple + +import numpy as np + +try: + from .fp8_e4m3_utils_np import ( + TINY, + compute_S_dec_f32_before_cast_te, + compute_S_enc_from_amax_1x64_like_te, + e4m3_u8_to_f32, + f32_to_e4m3_u8, + ) +except ImportError: + import os + import sys + + sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + from fp8_e4m3_utils_np import ( + TINY, + compute_S_dec_f32_before_cast_te, + compute_S_enc_from_amax_1x64_like_te, + e4m3_u8_to_f32, + f32_to_e4m3_u8, + ) + +COARSE = 64 +FINE = 16 + +FP4_E2M1_GRID = np.array( + [ + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, + ], + dtype=np.float32, +) + + +def _round_to_nearest_fp4(x: np.ndarray) -> np.ndarray: + d = np.abs(x[..., None] - FP4_E2M1_GRID) + return FP4_E2M1_GRID[np.argmin(d, axis=-1).astype(np.int64)] + + +def _pack_fp4_along_k(x_fp4: np.ndarray) -> np.ndarray: + m, k = x_fp4.shape + d = np.abs(x_fp4[..., None] - FP4_E2M1_GRID) + nibble = np.argmin(d, axis=-1).astype(np.uint8) + n_pairs = (k + 1) // 2 + out = np.zeros((m, n_pairs), dtype=np.uint8) + for p in range(k // 2): + j = 2 * p + out[:, p] = (nibble[:, j] & 0x0F) | ((nibble[:, j + 1] & 0x0F) << 4) + if k % 2 == 1: + out[:, -1] = nibble[:, -1] & 0x0F + return out + + +def _unpack_fp4_along_k(data_u8: np.ndarray, k: int) -> np.ndarray: + m = data_u8.shape[0] + out = np.empty((m, k), dtype=np.float32) + p = 0 + for j in range(0, k - 1, 2): + b = data_u8[:, p] + p += 1 + out[:, j] = FP4_E2M1_GRID[(b & 0x0F).astype(np.int64)] + out[:, j + 1] = FP4_E2M1_GRID[((b >> 4) & 0x0F).astype(np.int64)] + if k % 2 == 1: + b = data_u8[:, p] + out[:, -1] = FP4_E2M1_GRID[(b & 0x0F).astype(np.int64)] + return out + + +def _pack_fp4_along_m(x_fp4: np.ndarray) -> np.ndarray: + m, k = x_fp4.shape + d = np.abs(x_fp4[..., None] - FP4_E2M1_GRID) + nibble = np.argmin(d, axis=-1).astype(np.uint8) + n_pairs = (m + 1) // 2 + out = np.zeros((n_pairs, k), dtype=np.uint8) + for p in range(m // 2): + r = 2 * p + out[p, :] = (nibble[r, :] & 0x0F) | ((nibble[r + 1, :] & 0x0F) << 4) + if m % 2 == 1: + out[-1, :] = nibble[-1, :] & 0x0F + return out + + +def _unpack_fp4_along_m(data_u8: np.ndarray, m: int, k: int) -> np.ndarray: + out = np.empty((m, k), dtype=np.float32) + p = 0 + for r in range(0, m - 1, 2): + b = data_u8[p, :] + p += 1 + out[r, :] = FP4_E2M1_GRID[(b & 0x0F).astype(np.int64)] + out[r + 1, :] = FP4_E2M1_GRID[((b >> 4) & 0x0F).astype(np.int64)] + if m % 2 == 1: + b = data_u8[p, :] + out[-1, :] = FP4_E2M1_GRID[(b & 0x0F).astype(np.int64)] + return out + + +@dataclass +class HierarchicalNVFP4RowwiseNp: + m: int + k: int + data_u8: np.ndarray + S_enc: np.ndarray + S_dec_u8: np.ndarray + amax_64: np.ndarray + + +@dataclass +class HierarchicalNVFP4ColwiseNp: + m: int + k: int + data_u8: np.ndarray + S_enc: np.ndarray + S_dec_u8: np.ndarray + amax_64: np.ndarray + + +def _amax_64_k(x: np.ndarray, m: int, k: int) -> np.ndarray: + n64 = (k + COARSE - 1) // COARSE + out = np.empty((m, n64), dtype=np.float32) + for row in range(m): + for t64 in range(n64): + lo, hi = t64 * COARSE, min((t64 + 1) * COARSE, k) + seg = x[row, lo:hi] + out[row, t64] = float(np.max(np.abs(seg))) if seg.size else 0.0 + return out + + +def quantize_rowwise_1x64_1x16(x: np.ndarray, eps: float = 1e-12) -> HierarchicalNVFP4RowwiseNp: + x = np.asarray(x, dtype=np.float32) + assert 
x.ndim == 2 + m, k = int(x.shape[0]), int(x.shape[1]) + n16 = (k + FINE - 1) // FINE + amax_64 = _amax_64_k(x, m, k) + n64 = amax_64.shape[1] + S_enc = np.empty((m, n64), dtype=np.float32) + for ri in range(m): + for t64 in range(n64): + S_enc[ri, t64] = compute_S_enc_from_amax_1x64_like_te(float(amax_64[ri, t64])) + S_dec_u8 = np.empty((m, n16), dtype=np.uint8) + w = np.empty_like(x) + for row in range(m): + t16b = 0 + while t16b * FINE < k: + lo, hi = t16b * FINE, min((t16b + 1) * FINE, k) + t64 = lo // COARSE + S = float(S_enc[row, t64]) + segx = x[row, lo:hi] + bamax = float(np.max(np.abs(segx))) + if bamax < eps: + bamax = float(eps) + raw = compute_S_dec_f32_before_cast_te(bamax, S) + u = f32_to_e4m3_u8(np.array(raw, dtype=np.float32).reshape(1)) + S_dec_u8[row, t16b] = u.ravel()[0] + s_dec_f = max(float(e4m3_u8_to_f32(S_dec_u8[row, t16b : t16b + 1])[0]), TINY) + bsi = S / s_dec_f + w[row, lo:hi] = segx * bsi + t16b += 1 + q = _round_to_nearest_fp4(w) + return HierarchicalNVFP4RowwiseNp(m, k, _pack_fp4_along_k(q), S_enc, S_dec_u8, amax_64) + + +def dequantize_rowwise(p: HierarchicalNVFP4RowwiseNp) -> np.ndarray: + m, k = p.m, p.k + q = _unpack_fp4_along_k(p.data_u8, k) + t16 = (np.arange(k) // FINE).astype(np.int64) + t64 = (np.arange(k) // COARSE).astype(np.int64) + sdec = e4m3_u8_to_f32(p.S_dec_u8[:, t16].astype(np.uint8)) + senc = p.S_enc[:, t64] + sdec = np.maximum(sdec, TINY) + return (q * (sdec / senc)).astype(np.float32) + + +def _amax_64_m(x: np.ndarray, m: int, k: int) -> np.ndarray: + n64 = (m + COARSE - 1) // COARSE + out = np.empty((n64, k), dtype=np.float32) + for col in range(k): + for t64 in range(n64): + lo, hi = t64 * COARSE, min((t64 + 1) * COARSE, m) + seg = x[lo:hi, col] + out[t64, col] = float(np.max(np.abs(seg))) if seg.size else 0.0 + return out + + +def quantize_columnwise_1x64_1x16(x: np.ndarray, eps: float = 1e-12) -> HierarchicalNVFP4ColwiseNp: + x = np.asarray(x, dtype=np.float32) + assert x.ndim == 2 + m, k = int(x.shape[0]), int(x.shape[1]) + n16 = (m + FINE - 1) // FINE + amax_64 = _amax_64_m(x, m, k) + n64 = amax_64.shape[0] + S_enc = np.empty((n64, k), dtype=np.float32) + for t64 in range(n64): + for col in range(k): + S_enc[t64, col] = compute_S_enc_from_amax_1x64_like_te(float(amax_64[t64, col])) + S_dec_u8 = np.empty((n16, k), dtype=np.uint8) + w = np.empty_like(x) + for col in range(k): + t16b = 0 + while t16b * FINE < m: + lo, hi = t16b * FINE, min((t16b + 1) * FINE, m) + t64 = lo // COARSE + S = float(S_enc[t64, col]) + segx = x[lo:hi, col] + bamax = float(np.max(np.abs(segx))) + if bamax < eps: + bamax = float(eps) + raw = compute_S_dec_f32_before_cast_te(bamax, S) + u = f32_to_e4m3_u8(np.array(raw, dtype=np.float32).reshape(1)) + S_dec_u8[t16b, col] = u.ravel()[0] + s_dec_f = max( + float(e4m3_u8_to_f32(S_dec_u8[t16b : t16b + 1, col : col + 1])[0, 0]), TINY + ) + bsi = S / s_dec_f + w[lo:hi, col] = segx * bsi + t16b += 1 + q = _round_to_nearest_fp4(w) + return HierarchicalNVFP4ColwiseNp(m, k, _pack_fp4_along_m(q), S_enc, S_dec_u8, amax_64) + + +def dequantize_colwise(p: HierarchicalNVFP4ColwiseNp) -> np.ndarray: + m, k = p.m, p.k + q = _unpack_fp4_along_m(p.data_u8, m, k) + r16 = (np.arange(m) // FINE).astype(np.int64) + r64 = (np.arange(m) // COARSE).astype(np.int64) + sdec = e4m3_u8_to_f32(p.S_dec_u8[r16, :].astype(np.uint8)) + senc = p.S_enc[r64, :] + sdec = np.maximum(sdec, TINY) + return (q * (sdec / senc)).astype(np.float32) + + +def reference_matmul_tn( + a_rows: HierarchicalNVFP4RowwiseNp, + b_cols: HierarchicalNVFP4ColwiseNp, +) -> 
np.ndarray: + return dequantize_rowwise(a_rows) @ dequantize_colwise(b_cols).T + + +def roundtrip_error(x: np.ndarray, mode: str) -> Tuple[np.ndarray, float]: + if mode == "rowwise": + p = quantize_rowwise_1x64_1x16(x) + y = dequantize_rowwise(p) + elif mode == "colwise": + p = quantize_columnwise_1x64_1x16(x) + y = dequantize_colwise(p) + else: + raise ValueError("mode is rowwise or colwise") + e = float(np.max(np.abs(x.astype(np.float32) - y))) + return y, e + + +def _self_test() -> None: + rng = np.random.default_rng(0) + m, n, kdim = 4, 5, 128 + a = rng.standard_normal((m, kdim)).astype(np.float32) + b = rng.standard_normal((n, kdim)).astype(np.float32) + pa, pb = quantize_rowwise_1x64_1x16(a), quantize_columnwise_1x64_1x16(b) + y = reference_matmul_tn(pa, pb) + a @ b.T + # loose bound (aggressive low precision) + assert np.isfinite(y).all() + assert dequantize_rowwise(pa).shape == a.shape + assert dequantize_colwise(pb).shape == b.shape + print("hierarchical_nvfp4_ref_numpy (TE1x64 path): _self_test OK") + + +if __name__ == "__main__": + _self_test() diff --git a/tests/pytorch/nvfp4/test_nvfp4_1x64_accuracy.py b/tests/pytorch/nvfp4/test_nvfp4_1x64_accuracy.py new file mode 100644 index 0000000000..84361a1456 --- /dev/null +++ b/tests/pytorch/nvfp4/test_nvfp4_1x64_accuracy.py @@ -0,0 +1,387 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Accuracy comparison: hierarchical 1x64 + 1x16 vs per-tensor + 1x16. + +The hypothesis under test is that per-1x64-window ``S_enc`` (the hierarchical +scheme) reconstructs ``x`` at least as accurately as per-tensor ``S_enc`` +(the production NVFP4 scheme), and strictly better when the K-direction +magnitude varies meaningfully across windows -- which is the case the +hierarchical scheme is built for. + +Methodology +----------- +This is a *spec-level* accuracy test, not a kernel test: + +1. Quantize the same input through two pure-PyTorch references -- + ``NVFP4QuantizerRef`` (per-tensor) and ``NVFP4Quantizer1x64Ref`` + (per-1x64-window) -- to obtain ``(qx, sx)`` for each. +2. Dequantize both back to fp32 using each scheme's *own* ``S_enc``: + ``x_recon = q * s_dec_fp32 / S_enc``. For the per-tensor scheme + ``S_enc`` is a scalar; for 1x64 it is per-window. (Using the per-tensor + ``S_enc`` to dequantize a 1x64-encoded tensor would re-introduce the + GEMM-mismatch bug -- a separate concern that is documented elsewhere.) +3. Compute reconstruction error metrics (RMSE, max abs, Frobenius-relative) + against the original fp32 input. +4. Assert ``rmse_1x64`` is at most a small slack worse than ``rmse_pt`` on + benign inputs, and strictly better on inputs with per-window dynamic + range. Numbers are also printed so a regression failure is diagnostic. + +Because the bit-exact tests in ``test_nvfp4_1x64_quantize_exact.py`` already +certify ``NVFP4Quantizer1x64Ref`` reproduces the CUDA kernel byte-for-byte, +any accuracy conclusion we draw at the reference level transfers directly +to the kernel. 
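+
+For concreteness, a worked instance of the step-2 dequant formula (illustrative
+numbers only, not taken from any test case): an element whose FP4 code decodes
+to ``q = 4.0``, with an ``s_dec`` byte that decodes to ``1.5`` and
+``S_enc = 896.0`` (i.e. amax = 3.0, since ``S_enc = 2688 / amax``), reconstructs
+as ``4.0 * 1.5 / 896.0 ~= 6.7e-3``.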
+""" + +from __future__ import annotations + +from typing import Tuple + +import pytest +import torch + +from transformer_engine.pytorch.custom_recipes import utils +from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import ( + NVFP4QuantizerRef, + cast_from_fp4x2, +) +from transformer_engine.pytorch.custom_recipes.quantization_nvfp4_1x64 import ( + BLOCK_K, + FLOAT4_E2M1_MAX, + FLOAT8_E4M3_MAX, + NVFP4Quantizer1x64Ref, + WINDOW_K, +) + + +_SAFE_AMAX_FLOOR = 1e-12 + + +def _per_tensor_s_enc(x: torch.Tensor) -> torch.Tensor: + """Per-tensor encoding scaling factor used by stock NVFP4. + + Returns a per-element broadcast of shape ``(M, N)`` so the dequant + formula can be expressed elementwise without further care for layout. + """ + M, N = x.shape + fp32_max = torch.tensor(torch.finfo(torch.float32).max, device=x.device, dtype=torch.float32) + global_amax = torch.amax(torch.abs(x.to(torch.float32))) + s = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.clamp(global_amax, min=_SAFE_AMAX_FLOOR) + s = torch.minimum(s, fp32_max) + if float(s.item()) == 0.0: + s = torch.ones_like(s) + return s.expand(M, N).contiguous() + + +def _per_window_s_enc(x: torch.Tensor) -> torch.Tensor: + """Per-1x64-window encoding scaling factor, broadcast to ``(M, N)``.""" + M, N = x.shape + pad_n = (WINDOW_K - N % WINDOW_K) % WINDOW_K + if pad_n > 0: + x_padded = torch.nn.functional.pad(x, (0, pad_n), mode="constant", value=0.0) + else: + x_padded = x.contiguous() + Np = x_padded.shape[1] + n_win = Np // WINDOW_K + + fp32_max = torch.tensor(torch.finfo(torch.float32).max, device=x.device, dtype=torch.float32) + x_padded_fp32 = x_padded.to(torch.float32).view(M, n_win, WINDOW_K) + tile_amax = torch.amax(torch.abs(x_padded_fp32), dim=-1, keepdim=True) + tile_amax_safe = torch.clamp(tile_amax, min=_SAFE_AMAX_FLOOR) + + s = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / tile_amax_safe + s = torch.minimum(s, fp32_max) + s = torch.where( + (tile_amax_safe == 0) | (s == 0), + torch.ones_like(s), + s, + ) + s_per_elt = s.squeeze(-1).repeat_interleave(WINDOW_K, dim=1) + return s_per_elt[:, :N].contiguous() + + +def _dequantize( + qx_packed: torch.Tensor, + sx_e4m3: torch.Tensor, + s_enc_per_elt: torch.Tensor, + M: int, + N: int, +) -> torch.Tensor: + """Inverse of NVFP4 forward quantization: ``q * s_dec_fp32 / S_enc``.""" + q_fp32 = cast_from_fp4x2(qx_packed.view(torch.uint8), torch.float32) + sx_fp32 = sx_e4m3.view(torch.float8_e4m3fn).to(torch.float32) + sx_fp32 = sx_fp32[:M, : N // BLOCK_K] + sx_per_elt = sx_fp32.repeat_interleave(BLOCK_K, dim=1) + return q_fp32 * sx_per_elt / s_enc_per_elt + + +def _recon_per_tensor(x: torch.Tensor) -> torch.Tensor: + """Forward + inverse via the production per-tensor NVFP4 reference.""" + M, N = x.shape + quantizer = NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + rowwise=True, + columnwise=False, + pow_2_scales=False, + eps=0.0, + quant_tile_shape=(1, 16), + with_rht=False, + ) + out = quantizer.quantize(x) + s_enc = _per_tensor_s_enc(x) + return _dequantize(out.data, out.scale, s_enc, M, N) + + +def _recon_1x64(x: torch.Tensor) -> torch.Tensor: + """Forward + inverse via the hierarchical 1x64 reference.""" + M, N = x.shape + out = NVFP4Quantizer1x64Ref().quantize(x) + s_enc = _per_window_s_enc(x) + return _dequantize(out.data, out.scale, s_enc, M, N) + + +def _err_metrics(x: torch.Tensor, recon: torch.Tensor) -> Tuple[float, float, float]: + """Return ``(rmse, max_abs_err, frobenius_relative)`` of ``recon - x``.""" + x_fp32 = x.to(torch.float32) + diff = recon.to(torch.float32) 
- x_fp32 + rmse = torch.sqrt(torch.mean(diff * diff)).item() + max_err = torch.max(torch.abs(diff)).item() + denom = torch.linalg.norm(x_fp32).clamp(min=1e-30) + frob_rel = (torch.linalg.norm(diff) / denom).item() + return rmse, max_err, frob_rel + + +def _gen_gaussian(M: int, N: int, *, seed: int, device: str, dtype: torch.dtype) -> torch.Tensor: + """Uniform N(0, 1) -- a benign baseline where both schemes should tie.""" + g = torch.Generator(device=device).manual_seed(seed) + return torch.randn((M, N), generator=g, device=device, dtype=dtype) + + +def _gen_per_window_dynamic_range( + M: int, + N: int, + *, + seed: int, + device: str, + dtype: torch.dtype, + log10_lo: float, + log10_hi: float, +) -> torch.Tensor: + """Each 1x64 window has its own log-uniform magnitude scale. + + This generator is used to verify that 1x64 does **not regress** versus + per-tensor when per-window magnitude varies but stays inside E4M3's + representable range. It deliberately does not try to demonstrate a + large advantage: per-block ``s_dec`` already adapts to ``vec_max`` in + both schemes, so as long as ``s_dec`` does not underflow E4M3, FP4 + resolution per element is ``vec_max / 12`` regardless of the ``S_enc`` + choice, and the two schemes tie up to E4M3 rounding noise. + + To actually demonstrate the 1x64 advantage in absolute RMSE terms, see + ``_gen_sparse_extreme_outliers`` -- that distribution forces prod's + ``s_dec`` into the underflow regime for the bulk of the tensor. + """ + g = torch.Generator(device=device).manual_seed(seed) + n_win = (N + WINDOW_K - 1) // WINDOW_K + log_scales = torch.empty((M, n_win, 1), device=device, dtype=torch.float32) + log_scales.uniform_(log10_lo, log10_hi, generator=g) + scales = torch.pow(torch.tensor(10.0, device=device, dtype=torch.float32), log_scales) + base = torch.randn((M, n_win, WINDOW_K), generator=g, device=device, dtype=torch.float32) + x = (base * scales).reshape(M, n_win * WINDOW_K)[:, :N].contiguous() + return x.to(dtype) + + +def _gen_sparse_extreme_outliers( + M: int, + N: int, + *, + seed: int, + device: str, + dtype: torch.dtype, + outlier_mag: float = 1.0e6, + n_outlier_windows: int = 4, +) -> torch.Tensor: + """``N(0, 1)`` background plus a handful of extreme outlier elements. + + This is the distribution where 1x64's advantage is sharp in absolute + RMSE. The mechanism: + + * The per-tensor scheme derives ``S_enc_global`` from the global amax, + which the outliers drag up to ~``outlier_mag``. With + ``outlier_mag = 1e6``, ``S_enc_global ~ 2.7e-3``; for a non-outlier + 1x16 block (``vec_max ~ 2``) the resulting ``s_dec_fp32`` is + ``~9e-4``, **below E4M3's smallest subnormal (~2e-3)**. ``s_dec`` + rounds to zero, so reconstruction of every non-outlier block is + identically zero -- catastrophic loss for the bulk of the tensor. + + * The hierarchical scheme uses ``S_enc_tile`` per 1x64 window. In the + ~``M*n_win - n_outlier_windows`` non-outlier windows ``tile_amax`` + sees only ``N(0, 1)``, so ``S_enc_tile ~ 900``, ``s_dec`` stays well + inside E4M3, and FP4 quantization runs at standard precision. + Loss is restricted to the (small) bulk inside outlier-containing + windows. + + The expected RMSE ratio ``rmse_1x64 / rmse_pt`` is ``~ 0.06`` -- well + below the 0.5 threshold the test asserts. 
+ """ + g = torch.Generator(device=device).manual_seed(seed) + x = torch.randn((M, N), generator=g, device=device, dtype=torch.float32) + + n_win = N // WINDOW_K + if n_win == 0: + # Tensor is narrower than one full 1x64 window; the scheme degenerates + # and there is no meaningful advantage to test. + return x.to(dtype) + + total_windows = M * n_win + n_outlier_windows = min(n_outlier_windows, total_windows) + flat_idx = torch.randperm(total_windows, generator=g, device=device)[:n_outlier_windows] + outlier_rows = flat_idx // n_win + outlier_cols = (flat_idx % n_win) * WINDOW_K + signs = ( + torch.randint(0, 2, (n_outlier_windows,), generator=g, device=device, dtype=torch.int32).to( + torch.float32 + ) + * 2 + - 1 + ) + x[outlier_rows, outlier_cols] = signs * float(outlier_mag) + + return x.to(dtype) + + +_NEEDS_CUDA = pytest.mark.skipif( + not torch.cuda.is_available(), + reason="accuracy comparison runs on CUDA to mirror the rest of the nvfp4 suite", +) + + +@_NEEDS_CUDA +@pytest.mark.parametrize("M, N", [(256, 1024), (512, 2048), (1024, 1024)]) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("seed", [0, 1, 2]) +def test_1x64_at_least_as_good_as_per_tensor_on_gaussian( + M: int, N: int, x_dtype: torch.dtype, seed: int, capsys +) -> None: + """On uniform Gaussian inputs the two schemes should be roughly tied. + + ``S_enc_global`` and per-window ``S_enc_tile`` see similar amax values + when the input magnitude is uniform across K, so we expect ratios near + 1.0. We allow a 5% slack to absorb E4M3-rounding noise; a regression + that materially worsens 1x64 accuracy on benign inputs would still be + caught. + """ + device = "cuda" + x = _gen_gaussian(M, N, seed=seed, device=device, dtype=x_dtype) + + rmse_pt, max_pt, fro_pt = _err_metrics(x, _recon_per_tensor(x)) + rmse_1x64, max_1x64, fro_1x64 = _err_metrics(x, _recon_1x64(x)) + + with capsys.disabled(): + print( + f"\n[gaussian {M}x{N} {x_dtype} seed={seed}]" + f" rmse: pt={rmse_pt:.4e} 1x64={rmse_1x64:.4e}" + f" ratio={rmse_1x64 / max(rmse_pt, 1e-30):.3f} |" + f" max_abs: pt={max_pt:.4e} 1x64={max_1x64:.4e} |" + f" frob_rel: pt={fro_pt:.4e} 1x64={fro_1x64:.4e}" + ) + + assert rmse_1x64 <= rmse_pt * 1.05, ( + "1x64 RMSE unexpectedly worse than per-tensor on uniform input: " + f"rmse_1x64={rmse_1x64:.4e} > 1.05 * rmse_pt={rmse_pt:.4e}" + ) + + +@_NEEDS_CUDA +@pytest.mark.parametrize("M, N", [(256, 1024), (512, 2048)]) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("seed", [0, 1, 2]) +def test_1x64_better_than_per_tensor_on_sparse_extreme_outliers( + M: int, N: int, x_dtype: torch.dtype, seed: int, capsys +) -> None: + """Sparse extreme outliers force prod's ``s_dec`` into E4M3 underflow. + + Per-block ``s_dec`` already adapts to local ``vec_max`` in both + schemes, so within E4M3's representable range FP4 resolution per + element is ``~ vec_max / 12`` independent of the ``S_enc`` choice -- + the schemes tie. The 1x64 advantage in *absolute* RMSE terms shows up + only when prod's E4M3 ``s_dec`` underflows, and only when those + underflowed positions also dominate ``mean(x^2)``. + + To put both conditions in effect simultaneously, this test uses an + ``N(0, 1)`` bulk plus a handful of outliers of magnitude ``1e6``: + + * ``S_enc_global = (FP8_MAX * FP4_MAX) / amax ~ 2.7e-3`` -- + ``s_dec`` for non-outlier blocks underflows; reconstruction of the + entire bulk collapses to zero. ``rmse_pt ~ 1.0``. 
+ + * ``S_enc_tile`` for ~99% of windows is set by ``tile_amax ~ 3`` + (no outlier present); FP4 quantization runs at standard precision. + ``rmse_1x64 ~ 0.06``. + + The assertion only requires 1x64 to be strictly better than + per-tensor; the comparison values are printed for each parametrized + case so the actual margin (empirically ``rmse_1x64 / rmse_pt`` is + around ``0.05 - 0.10``) is visible in the test log. + """ + device = "cuda" + x = _gen_sparse_extreme_outliers(M, N, seed=seed, device=device, dtype=x_dtype) + + rmse_pt, max_pt, fro_pt = _err_metrics(x, _recon_per_tensor(x)) + rmse_1x64, max_1x64, fro_1x64 = _err_metrics(x, _recon_1x64(x)) + + with capsys.disabled(): + print( + f"\n[sparse_outliers {M}x{N} {x_dtype} seed={seed}]" + f" rmse: pt={rmse_pt:.4e} 1x64={rmse_1x64:.4e}" + f" ratio={rmse_1x64 / max(rmse_pt, 1e-30):.3f} |" + f" max_abs: pt={max_pt:.4e} 1x64={max_1x64:.4e} |" + f" frob_rel: pt={fro_pt:.4e} 1x64={fro_1x64:.4e}" + ) + + assert rmse_1x64 < rmse_pt, ( + "1x64 was not strictly better than per-tensor on " + "sparse-extreme-outlier input: " + f"rmse_1x64={rmse_1x64:.4e} >= rmse_pt={rmse_pt:.4e}" + ) + + +@_NEEDS_CUDA +@pytest.mark.parametrize("M, N", [(256, 1024)]) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("seed", [0, 1, 2]) +def test_1x64_at_least_tied_on_modest_per_window_dynamic_range( + M: int, N: int, x_dtype: torch.dtype, seed: int, capsys +) -> None: + """A moderate per-window dynamic range still favours 1x64. + + Ratio range of ~30x is well within E4M3's representable scale range, + so the per-tensor scheme does not catastrophically underflow. The + advantage from local ``S_enc`` is therefore smaller -- but still + present. We assert merely "no worse" plus a generous 5% slack; the + strict inequality test above is the one that demonstrates the win, + while this case ensures the win does not invert when the dynamic + range shrinks. + """ + device = "cuda" + x = _gen_per_window_dynamic_range( + M, N, seed=seed, device=device, dtype=x_dtype, log10_lo=-1.5, log10_hi=0.5 + ) + + rmse_pt, max_pt, fro_pt = _err_metrics(x, _recon_per_tensor(x)) + rmse_1x64, max_1x64, fro_1x64 = _err_metrics(x, _recon_1x64(x)) + + with capsys.disabled(): + print( + f"\n[dyn_range_modest {M}x{N} {x_dtype} seed={seed}]" + f" rmse: pt={rmse_pt:.4e} 1x64={rmse_1x64:.4e}" + f" ratio={rmse_1x64 / max(rmse_pt, 1e-30):.3f} |" + f" max_abs: pt={max_pt:.4e} 1x64={max_1x64:.4e} |" + f" frob_rel: pt={fro_pt:.4e} 1x64={fro_1x64:.4e}" + ) + + assert rmse_1x64 <= rmse_pt * 1.05, ( + "1x64 unexpectedly worse than per-tensor on modest dynamic-range " + f"input: rmse_1x64={rmse_1x64:.4e} > 1.05 * rmse_pt={rmse_pt:.4e}" + ) diff --git a/tests/pytorch/nvfp4/test_nvfp4_1x64_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_1x64_quantize_exact.py new file mode 100644 index 0000000000..7a56b0f8c2 --- /dev/null +++ b/tests/pytorch/nvfp4/test_nvfp4_1x64_quantize_exact.py @@ -0,0 +1,225 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Bit-exact tests for the NVFP4 hierarchical 1x64 cast CUDA kernel. + +Methodology mirrors ``test_nvfp4_quantize_exact.py``: invoke the kernel +through ``NVFP4Quantizer`` (with the ``NVTE_NVFP4_ROWWISE_1X64_LOCAL_ENCODE=1`` +env flag enabling the alternate dispatch), and compare every output buffer +against a pure-PyTorch oracle (``NVFP4Quantizer1x64Ref``) byte-for-byte +(``atol=rtol=0``). 
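+
+FP4 payloads are unpacked to one nibble per byte (``unpack_fp4`` below) before
+comparison, so a mismatch points at an individual element rather than at a
+packed pair.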
+ +The kernel writes up to six tensors per call: + +* rowwise FP4 data, rowwise per-1x16 E4M3 ``s_dec``, rowwise per-1x64 + window amax; +* columnwise (transposed) FP4 data, columnwise per-1x16 E4M3 ``s_dec``, + columnwise per-1x64 window amax. + +This file covers the rowwise-only, columnwise-only, and rowwise+columnwise +configurations -- the latter being the production-equivalent fused mode. +""" + +import pytest +import torch + +import transformer_engine.pytorch as te +import transformer_engine_torch as tex +from transformer_engine.pytorch import NVFP4Quantizer +from transformer_engine.pytorch.custom_recipes.quantization_nvfp4_1x64 import ( + NVFP4Quantizer1x64Ref, +) + + +recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) + + +def unpack_fp4(x: torch.Tensor) -> torch.Tensor: + """Unpack two FP4 values per byte into one ``uint8`` value per element.""" + repeated = x.repeat_interleave(2, dim=1) + repeated[:, 0::2] &= 0x0F + repeated[:, 1::2] >>= 4 + return repeated + + +def _check_quantization_1x64_versus_reference_with_input( + x: torch.Tensor, + *, + rowwise: bool, + columnwise: bool, +) -> None: + """Quantize ``x`` through both kernel and reference and assert that every + requested output (data, scale, window-amax, on each requested direction) + matches bit-exactly. + """ + te_dtype = tex.DType.kFloat4E2M1 + + quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=rowwise, + columnwise=columnwise, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=False, + ) + sut = quantizer(x) + ref = NVFP4Quantizer1x64Ref(rowwise=rowwise, columnwise=columnwise).quantize(x) + + if rowwise: + qx_sut = unpack_fp4(sut._rowwise_data.view(dtype=torch.uint8)) + qx_ref = unpack_fp4(ref.data.view(dtype=torch.uint8)) + sx_sut = sut._rowwise_scale_inv.view(dtype=torch.uint8) + sx_ref = ref.scale.view(dtype=torch.uint8) + # The kernel may pad qx/sx to alignment boundaries; only the + # un-padded prefix is meaningfully written. + torch.testing.assert_close( + qx_sut[: qx_ref.shape[0], : qx_ref.shape[1]], qx_ref, atol=0.0, rtol=0.0 + ) + torch.testing.assert_close( + sx_sut[: sx_ref.shape[0], : sx_ref.shape[1]], sx_ref, atol=0.0, rtol=0.0 + ) + torch.testing.assert_close(sut._amax_rowwise, ref.window_amax_row, atol=0.0, rtol=0.0) + + if columnwise: + qxt_sut = unpack_fp4(sut._columnwise_data.view(dtype=torch.uint8)) + qxt_ref = unpack_fp4(ref.columnwise_data.view(dtype=torch.uint8)) + sxt_sut = sut._columnwise_scale_inv.view(dtype=torch.uint8) + sxt_ref = ref.columnwise_scale.view(dtype=torch.uint8) + torch.testing.assert_close( + qxt_sut[: qxt_ref.shape[0], : qxt_ref.shape[1]], qxt_ref, atol=0.0, rtol=0.0 + ) + torch.testing.assert_close( + sxt_sut[: sxt_ref.shape[0], : sxt_ref.shape[1]], sxt_ref, atol=0.0, rtol=0.0 + ) + torch.testing.assert_close(sut._amax_columnwise, ref.window_amax_col, atol=0.0, rtol=0.0) + + +def _check_random( + x_dtype: torch.dtype, + M: int, + N: int, + *, + rowwise: bool, + columnwise: bool, +) -> None: + """Random-input variant. Seeds are fixed so failures reproduce.""" + device = "cuda" + seed = 0 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + x = torch.randn((M, N), dtype=x_dtype, device=device) + _check_quantization_1x64_versus_reference_with_input(x, rowwise=rowwise, columnwise=columnwise) + + +# Shapes where both M and N are multiples of 64 -- the 1x64 hierarchy's +# strict alignment requirement (enforced by the dispatcher). 
+_SHAPES_64x64_MULTIPLE = [ + (64, 64), + (128, 128), + (256, 256), + (256, 512), + (1024, 256), + (256, 1024), + (2048, 2048), + (1024, 2048), + # Non-square shapes that exercise distinct row/col tile counts -- the + # columnwise pass uses M/64 windows, the rowwise uses N/64. + (64, 256), + (256, 64), + (192, 384), + (384, 192), +] + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize("M, N", _SHAPES_64x64_MULTIPLE) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +def test_nvfp4_1x64_quantize_rowwise(monkeypatch, x_dtype: torch.dtype, M: int, N: int) -> None: + """Rowwise-only configuration -- preserves the original PR's coverage.""" + monkeypatch.setenv("NVTE_NVFP4_ROWWISE_1X64_LOCAL_ENCODE", "1") + monkeypatch.setenv("NVTE_NVFP4_DISABLE_RHT", "1") + _check_random(x_dtype, M, N, rowwise=True, columnwise=False) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize("M, N", _SHAPES_64x64_MULTIPLE) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +def test_nvfp4_1x64_quantize_columnwise(monkeypatch, x_dtype: torch.dtype, M: int, N: int) -> None: + """Columnwise-only -- exercises the transposed output path on its own.""" + monkeypatch.setenv("NVTE_NVFP4_ROWWISE_1X64_LOCAL_ENCODE", "1") + monkeypatch.setenv("NVTE_NVFP4_DISABLE_RHT", "1") + _check_random(x_dtype, M, N, rowwise=False, columnwise=True) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize("M, N", _SHAPES_64x64_MULTIPLE) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +def test_nvfp4_1x64_quantize_rowwise_columnwise( + monkeypatch, x_dtype: torch.dtype, M: int, N: int +) -> None: + """Fused rowwise+columnwise -- the production-equivalent configuration.""" + monkeypatch.setenv("NVTE_NVFP4_ROWWISE_1X64_LOCAL_ENCODE", "1") + monkeypatch.setenv("NVTE_NVFP4_DISABLE_RHT", "1") + _check_random(x_dtype, M, N, rowwise=True, columnwise=True) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize( + "rowwise, columnwise", + [(True, False), (False, True), (True, True)], + ids=["row", "col", "rowcol"], +) +def test_nvfp4_1x64_quantize_extrema( + monkeypatch, x_dtype: torch.dtype, rowwise: bool, columnwise: bool +) -> None: + """Stress the saturating cast and ``S_enc`` clamps on each direction. + + Inputs: + * an outlier-heavy row (one ~FP4_MAX value per 1x64 K-window plus tiny + noise) -- exercises the rowwise ``s_dec`` saturation branch; + * an outlier-heavy column (one ~FP4_MAX value per 1x64 M-window plus + tiny noise) -- the columnwise mirror of the above; + * an all-zero region -- exercises the ``tile_amax == 0`` fallback that + promotes ``S_enc`` to 1.0 and the kernel's ``s_dec == 0`` + short-circuit (without which ``cvt.rn.satfinite.e2m1x4.f32(NaN)`` + on SM10 would saturate to ``+FP4_MAX = 0x7`` instead of ``0x0``); + * a uniform constant block -- ``s_dec`` should be the same E4M3 byte + for every block in the affected window. + """ + monkeypatch.setenv("NVTE_NVFP4_ROWWISE_1X64_LOCAL_ENCODE", "1") + monkeypatch.setenv("NVTE_NVFP4_DISABLE_RHT", "1") + + device = "cuda" + M, N = 128, 256 + + torch.manual_seed(0) + torch.cuda.manual_seed(0) + x = torch.randn((M, N), dtype=x_dtype, device=device) * 0.01 + + # Row 0: per-1x64-K-window outlier on the rowwise side. 
+ for w in range(N // 64): + x[0, w * 64] = 5.5 + + # Col 0: per-1x64-M-window outlier on the columnwise side. Distinct row + # from row 0 so the two outlier patterns do not overlap. + for w in range(M // 64): + x[w * 64 + 1, 0] = 5.25 # row index != 0, col 0 + + # Row 1: all zero (degenerate window amax path on the rowwise side). + # We've already touched (1, 0) above so re-zero just (1, 1:) -- the + # block at (1, 0..15) still contains a single non-zero outlier 5.25, + # which is fine; the all-zero stress remains in (1, 16:). + x[1, 16:] = 0.0 + + # Col 1: all zero (degenerate window amax path on the columnwise side). + x[16:, 1] = 0.0 + + # Row 2: uniform constant -- exercises both directions' constant-block + # path simultaneously for that row's columns it spans. + x[2] = 0.75 + + _check_quantization_1x64_versus_reference_with_input(x, rowwise=rowwise, columnwise=columnwise) diff --git a/tests/pytorch/nvfp4/test_nvfp4_1x64_split_quantize.py b/tests/pytorch/nvfp4/test_nvfp4_1x64_split_quantize.py new file mode 100644 index 0000000000..e5bdc4b1dc --- /dev/null +++ b/tests/pytorch/nvfp4/test_nvfp4_1x64_split_quantize.py @@ -0,0 +1,354 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Bit-exact tests for NVFP4 1x64 local-encode through ``tex.split_quantize``. + +Together with ``test_nvfp4_1x64_quantize_exact.py`` (which covers single-tensor +quantize) these tests lock down the multi-chunk dispatch path that drives +``split_quantize_nvfp4_impl_helper`` (and its ``UNFUSED`` fallback) when the +1x64 local-encode flag is on. Two complementary axes are exercised: + +* **Axis A -- algorithm oracle.** + Compare the SUT (``tex.split_quantize``) against a pure-PyTorch oracle + (``NVFP4Quantizer1x64Ref``) byte-for-byte, applied per chunk. The oracle is + independent of any TransformerEngine CUDA kernel, so any deviation must + come from the kernel itself (rounding, scale derivation, packed FP4 + layout). This catches algorithmic regressions even when both C++ + dispatch paths are wired identically. + +* **Axis B -- wiring oracle.** + Compare the SUT against ``quantizers[i](chunk)`` running per chunk via + ``reference_group_quantize``. Both sides ultimately invoke the same + ``quantize_1x64_local_encode`` kernel, so their outputs must be + bit-identical. Any mismatch points at the split-quantize driver itself -- + buffer allocation (``BULK_NVFP4`` vs. ``UNFUSED``), per-chunk config + propagation, ``nvte_quantize_v2`` arguments, or the inner-dim alignment + fallback in the dispatcher. + +Both axes use ``atol=rtol=0``: the 1x64 path is deterministic (no +stochastic rounding, no RHT random sign mask), so any non-zero diff is a +real regression. +""" + +import pytest +import torch + +import transformer_engine.pytorch as te +import transformer_engine_torch as tex +from transformer_engine.pytorch import NVFP4Quantizer +from transformer_engine.pytorch.custom_recipes.quantization_nvfp4_1x64 import ( + NVFP4Quantizer1x64Ref, +) + +from nvfp4_utils import ( + assert_same_shape_and_dtype, + generate_split_sections, + get_nvfp4_scale_shape_no_padding, + reference_group_quantize, +) + + +recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) + + +# ----------------------------------------------------------------------------- +# Shape matrix +# +# 1x64 hierarchy requires both M and N to be multiples of 64. 
The split-quantize +# dispatcher additionally has two routing branches that we want to exercise: +# +# * N % 128 == 0 -> ``QuantizationMethod::FUSED_NVFP4`` keeps the call on the +# fused ``split_quantize_nvfp4_impl_helper`` path. +# * N % 128 != 0 -> dispatcher mirrors the ``BULK_NVFP4`` ``% 128`` fallback +# and downgrades to ``QuantizationMethod::UNFUSED`` (per-tensor +# ``NVFP4Quantizer::quantize_impl`` loop). The kernel itself is the same; +# only the driver changes. +# +# Both branches are kept here so that a regression in either lane fails CI. +# ----------------------------------------------------------------------------- +_SHAPES_FUSED_PATH = [ + (256, 1024), + (1024, 256), + (2048, 2048), +] +_SHAPES_FALLBACK_PATH = [ + (1024, 320), # N % 128 != 0, hits the dispatcher's UNFUSED fallback. +] +_SHAPES = _SHAPES_FUSED_PATH + _SHAPES_FALLBACK_PATH + + +def _unpack_fp4(x: torch.Tensor) -> torch.Tensor: + """Unpack two FP4 values per byte into one ``uint8`` per element.""" + repeated = x.repeat_interleave(2, dim=1) + repeated[:, 0::2] &= 0x0F + repeated[:, 1::2] >>= 4 + return repeated + + +def _make_quantizers(num: int, *, rowwise: bool, columnwise: bool): + """Build a list of ``NVFP4Quantizer`` configured for 1x64-compatible mode. + + 1x64 is mutually exclusive with RHT and 2D quantization; the constructor + flags here mirror that constraint so we never feed an invalid combination + into either dispatch path. + """ + return [ + NVFP4Quantizer( + fp4_dtype=tex.DType.kFloat4E2M1, + rowwise=rowwise, + columnwise=columnwise, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=False, + ) + for _ in range(num) + ] + + +def _make_input(M: int, N: int, dtype: torch.dtype) -> torch.Tensor: + torch.manual_seed(0) + torch.cuda.manual_seed(0) + return torch.randn((M, N), dtype=dtype, device="cuda") + + +# ----------------------------------------------------------------------------- +# Axis A: algorithm oracle. tex.split_quantize vs NVFP4Quantizer1x64Ref +# ----------------------------------------------------------------------------- +def _check_axis_a_against_oracle( + *, + M: int, + N: int, + dtype: torch.dtype, + split_sections: list[int], + rowwise: bool, + columnwise: bool, +) -> None: + x = _make_input(M, N, dtype) + chunks = torch.split(x, split_sections) + quantizers = _make_quantizers(len(split_sections), rowwise=rowwise, columnwise=columnwise) + + sut_outputs = tex.split_quantize(x, split_sections, quantizers) + + for i, (sut, chunk) in enumerate(zip(sut_outputs, chunks)): + if split_sections[i] == 0: + # Empty chunk: SUT just allocates zero-row buffers. The oracle's + # behaviour on an empty input is irrelevant -- we only need to + # know the kernel did not write past the chunk boundary. + if rowwise: + assert sut._rowwise_data.shape[0] == 0, f"chunk {i} rowwise data not empty" + assert sut._amax_rowwise.shape[0] == 0, f"chunk {i} rowwise amax not empty" + if columnwise: + # columnwise tensors are transposed: (N, M_chunk//2) etc. + assert sut._columnwise_data.shape[1] == 0, f"chunk {i} columnwise data not empty" + assert sut._amax_columnwise.shape[1] == 0, f"chunk {i} columnwise amax not empty" + continue + + ref = NVFP4Quantizer1x64Ref(rowwise=rowwise, columnwise=columnwise).quantize(chunk) + + if rowwise: + # Kernel pads qx/sx to alignment boundaries; only compare the + # un-padded prefix, exactly as test_nvfp4_1x64_quantize_exact.py + # does for the single-tensor path. 
+ qx_sut = _unpack_fp4(sut._rowwise_data.view(dtype=torch.uint8)) + qx_ref = _unpack_fp4(ref.data.view(dtype=torch.uint8)) + sx_sut = sut._rowwise_scale_inv.view(dtype=torch.uint8) + sx_ref = ref.scale.view(dtype=torch.uint8) + torch.testing.assert_close( + qx_sut[: qx_ref.shape[0], : qx_ref.shape[1]], qx_ref, atol=0.0, rtol=0.0 + ) + torch.testing.assert_close( + sx_sut[: sx_ref.shape[0], : sx_ref.shape[1]], sx_ref, atol=0.0, rtol=0.0 + ) + torch.testing.assert_close(sut._amax_rowwise, ref.window_amax_row, atol=0.0, rtol=0.0) + + if columnwise: + qxt_sut = _unpack_fp4(sut._columnwise_data.view(dtype=torch.uint8)) + qxt_ref = _unpack_fp4(ref.columnwise_data.view(dtype=torch.uint8)) + sxt_sut = sut._columnwise_scale_inv.view(dtype=torch.uint8) + sxt_ref = ref.columnwise_scale.view(dtype=torch.uint8) + torch.testing.assert_close( + qxt_sut[: qxt_ref.shape[0], : qxt_ref.shape[1]], qxt_ref, atol=0.0, rtol=0.0 + ) + torch.testing.assert_close( + sxt_sut[: sxt_ref.shape[0], : sxt_ref.shape[1]], sxt_ref, atol=0.0, rtol=0.0 + ) + torch.testing.assert_close( + sut._amax_columnwise, ref.window_amax_col, atol=0.0, rtol=0.0 + ) + + +# ----------------------------------------------------------------------------- +# Axis B: wiring oracle. tex.split_quantize vs per-chunk quantizer(chunk) +# ----------------------------------------------------------------------------- +def _check_axis_b_against_per_tensor( + *, + M: int, + N: int, + dtype: torch.dtype, + split_sections: list[int], + rowwise: bool, + columnwise: bool, +) -> None: + x = _make_input(M, N, dtype) + chunks = torch.split(x, split_sections) + + # Build TWO independent quantizer lists -- one drives the SUT, one drives + # the per-tensor reference -- to keep the comparison strictly between the + # two C++ dispatch paths and not accidentally reuse any quantizer state. + sut_quantizers = _make_quantizers(len(split_sections), rowwise=rowwise, columnwise=columnwise) + ref_quantizers = _make_quantizers(len(split_sections), rowwise=rowwise, columnwise=columnwise) + + sut_outputs = tex.split_quantize(x, split_sections, sut_quantizers) + + # ``reference_group_quantize`` calls ``ref_quantizers[i](chunk)`` per + # chunk -- with the 1x64 env flag on, this still runs the same + # ``quantize_1x64_local_encode`` kernel, just driven through the + # single-tensor entry point. SUT and reference must agree bit-for-bit + # on every byte the kernel actually writes. + qx_ref, sx_ref, amax_row_ref, qxt_ref, sxt_ref, amax_col_ref = reference_group_quantize( + x, ref_quantizers, split_sections, rowwise, columnwise + ) + + # NVFP4 scale buffers are over-allocated (rounded up to the cuBLAS + # block-scaling-factor layout: 128 in the outer dim, 4 in the inner dim) + # so that swizzle can run in place. The kernel only writes the un-padded + # prefix; the padded tail is left in whatever state ``at::empty`` + # returned, which differs across allocations. Slice both sides down to + # the valid prefix before doing a bit-exact compare. Data buffers + # ``(M, N/2)`` and 1x64 amax buffers ``(M, N/64)`` / ``(N, M/64)`` are + # already exact-sized, so they can be compared whole. 
+ if rowwise: + sut_qx = [out._rowwise_data.view(dtype=torch.uint8) for out in sut_outputs] + sut_sx = [out._rowwise_scale_inv for out in sut_outputs] + sut_amax = [out._amax_rowwise for out in sut_outputs] + for i in range(len(sut_outputs)): + if split_sections[i] == 0: + assert_same_shape_and_dtype(sut_qx[i], qx_ref[i]) + assert_same_shape_and_dtype(sut_sx[i], sx_ref[i]) + assert_same_shape_and_dtype(sut_amax[i], amax_row_ref[i]) + continue + torch.testing.assert_close(sut_qx[i], qx_ref[i], atol=0.0, rtol=0.0) + torch.testing.assert_close(sut_amax[i], amax_row_ref[i], atol=0.0, rtol=0.0) + valid = get_nvfp4_scale_shape_no_padding(chunks[i].shape, columnwise=False) + torch.testing.assert_close( + sut_sx[i][: valid[0], : valid[1]], + sx_ref[i][: valid[0], : valid[1]], + atol=0.0, + rtol=0.0, + ) + + if columnwise: + sut_qxt = [out._columnwise_data.view(dtype=torch.uint8) for out in sut_outputs] + sut_sxt = [out._columnwise_scale_inv for out in sut_outputs] + sut_amax_t = [out._amax_columnwise for out in sut_outputs] + for i in range(len(sut_outputs)): + if split_sections[i] == 0: + assert_same_shape_and_dtype(sut_qxt[i], qxt_ref[i]) + assert_same_shape_and_dtype(sut_sxt[i], sxt_ref[i]) + assert_same_shape_and_dtype(sut_amax_t[i], amax_col_ref[i]) + continue + torch.testing.assert_close(sut_qxt[i], qxt_ref[i], atol=0.0, rtol=0.0) + torch.testing.assert_close(sut_amax_t[i], amax_col_ref[i], atol=0.0, rtol=0.0) + valid = get_nvfp4_scale_shape_no_padding(chunks[i].shape, columnwise=True) + torch.testing.assert_close( + sut_sxt[i][: valid[0], : valid[1]], + sxt_ref[i][: valid[0], : valid[1]], + atol=0.0, + rtol=0.0, + ) + + +# ----------------------------------------------------------------------------- +# Pytest entry points +# ----------------------------------------------------------------------------- +@pytest.fixture +def _enable_1x64(monkeypatch): + """Enable 1x64 local-encode for both the SUT and the reference paths. + + The flag is read at every quantize call (it gates ``local_encode_from_env`` + in ``nvfp4_1x64.h``), so it must be set before any tensor flows through + either ``tex.split_quantize`` or ``quantizer(chunk)``. ``monkeypatch`` + scopes the change to a single test, so other tests in the same session + are unaffected. 
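+
+    A test opts in simply by listing the fixture as a parameter; monkeypatch
+    restores both env vars when the test finishes, e.g.::
+
+        def test_something(_enable_1x64):
+            ...  # 1x64 local encode on, RHT off, for this test only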
+ """ + monkeypatch.setenv("NVTE_NVFP4_ROWWISE_1X64_LOCAL_ENCODE", "1") + monkeypatch.setenv("NVTE_NVFP4_DISABLE_RHT", "1") + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize("M, N", _SHAPES) +@pytest.mark.parametrize( + "edge_cases", + [ + "regular", + "zero_tokens_front", + "zero_tokens_middle", + "random_uneven_split", + ], +) +@pytest.mark.parametrize( + "rowwise, columnwise", + [(True, False), (False, True), (True, True)], + ids=["row", "col", "rowcol"], +) +@pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str) +def test_split_quantize_1x64_axis_a_oracle( + _enable_1x64, + x_dtype: torch.dtype, + M: int, + N: int, + edge_cases: str, + rowwise: bool, + columnwise: bool, +) -> None: + """Axis A: split-quantize output must equal the pure-PyTorch 1x64 oracle.""" + split_sections = generate_split_sections(M, N, edge_cases, least_multiple=64) + _check_axis_a_against_oracle( + M=M, + N=N, + dtype=x_dtype, + split_sections=split_sections, + rowwise=rowwise, + columnwise=columnwise, + ) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize("M, N", _SHAPES) +@pytest.mark.parametrize( + "edge_cases", + [ + "regular", + "zero_tokens_front", + "zero_tokens_middle", + "random_uneven_split", + ], +) +@pytest.mark.parametrize( + "rowwise, columnwise", + [(True, False), (False, True), (True, True)], + ids=["row", "col", "rowcol"], +) +@pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str) +def test_split_quantize_1x64_axis_b_wiring( + _enable_1x64, + x_dtype: torch.dtype, + M: int, + N: int, + edge_cases: str, + rowwise: bool, + columnwise: bool, +) -> None: + """Axis B: split-quantize and per-tensor driver must produce identical bytes.""" + split_sections = generate_split_sections(M, N, edge_cases, least_multiple=64) + _check_axis_b_against_per_tensor( + M=M, + N=N, + dtype=x_dtype, + split_sections=split_sections, + rowwise=rowwise, + columnwise=columnwise, + ) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 3f684adbb4..034c113f70 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -219,6 +219,7 @@ list(APPEND transformer_engine_cuda_arch_specific_sources multi_tensor/compute_scale.cu recipe/mxfp8_scaling.cu recipe/nvfp4.cu + cast/nvfp4/quantize_nvfp4_1x64.cu transpose/quantize_transpose_square_blockwise.cu transpose/quantize_transpose_vector_blockwise_fp4.cu) diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index 8d985f64f3..f4e16077a0 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -22,6 +22,7 @@ #include "../mxfp8/quantize_mxfp8.cuh" #include "../nvfp4/group_quantize_transpose_nvfp4.cuh" #include "../nvfp4/quantize_nvfp4.cuh" +#include "../nvfp4/quantize_nvfp4_1x64.cuh" #include "../nvfp4/quantize_transpose_nvfp4.cuh" namespace transformer_engine { @@ -104,8 +105,19 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && (cols % 32 == 0) && output_tensor->has_data(); - // Launch NVFP4 quantize kernel - if (use_optimized_kernel) { + // Per-1x64-K-window S_enc (non-RHT). 
The kernel produces rowwise and/or + // columnwise outputs in a single fused tile pass; per-window amax is + // written into the existing ``amax`` / ``columnwise_amax`` slots, which + // the PyTorch wrapper allocates with shape (M, N/64) / (N, M/64) when + // ``nvfp4_1x64_local_encode`` is set. + if (quant_config_cpp.nvfp4_rowwise_1x64_local_encode) { + NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, + "NVFP4 1x64 local encode is incompatible with 2D block scaling."); + NVTE_CHECK(!quant_config_cpp.stochastic_rounding, + "NVFP4 1x64 local encode does not support stochastic rounding yet."); + nvfp4::quantize_1x64_local_encode(*input_tensor, *noop_tensor, output_tensor, + quant_config_cpp, stream); + } else if (use_optimized_kernel) { if (quant_config_cpp.nvfp4_2d_quantization) { nvfp4::quantize_transpose( *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); @@ -236,6 +248,10 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens CheckInputTensor(*grad_tensor, "input"); CheckOutputTensor(*output_tensor, "output", false); + NVTE_CHECK( + !quant_config_cpp.nvfp4_rowwise_1x64_local_encode, + "NVFP4 rowwise 1x64 local encode is not implemented for backward quantization yet."); + // Choose kernel int32_t rows = grad_tensor->flat_first_dim(); int32_t cols = grad_tensor->flat_last_dim(); diff --git a/transformer_engine/common/cast/nvfp4/quantize_nvfp4_1x64.cu b/transformer_engine/common/cast/nvfp4/quantize_nvfp4_1x64.cu new file mode 100644 index 0000000000..d31bab6ee3 --- /dev/null +++ b/transformer_engine/common/cast/nvfp4/quantize_nvfp4_1x64.cu @@ -0,0 +1,291 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "common/cast/nvfp4/core_nvfp4.cuh" +#include "common/cast/nvfp4/quantize_nvfp4_1x64.cuh" +#include "common/common.h" +#include "common/util/ptx.cuh" +#include "common/utils.cuh" + +namespace transformer_engine { +namespace dispatch { +namespace nvfp4 { +namespace { + +#if FP4_TYPE_SUPPORTED + +using core::compute_global_encode_scaling_factor_FP4; +using ptx::FPx2; +using quantization_SF::compute_decoding_scaling_factor; + +// One CUDA block = one 64x64 input tile in (M, N) row-major space. +// +// Grid layout: (blockIdx.x, blockIdx.y) = (n_window, m_tile), so each CTA +// owns exactly one 1x64 K-window for the rowwise pass *and* one 1x64 +// (transposed-K) M-window for the columnwise pass. With 64 threads per +// CTA, threadIdx.x doubles as "row index in tile" during the rowwise pass +// and "column index in tile" during the columnwise pass. +// +// Either pass can be skipped at runtime by passing nullptr for its three +// output buffers; the SMEM tile load is shared. +// +// SMEM is padded to ``[64][65]`` so the columnwise transpose access +// (``in_sm[e][tid]`` walking down a column) does not fall on the same +// 32-bank lane for every row. +template +__global__ void __launch_bounds__(64) + nvfp4_1x64_fused_per_tile(const IType* __restrict__ in, const size_t rows, const size_t cols, + const int ld_row_elts, + // Rowwise outputs (all three are non-null together, or all null). 
+ uint8_t* __restrict__ q_row, fp8e4m3* __restrict__ s_dec_row, + float* __restrict__ w_amax_row, const size_t s_dec_row_stride, + const size_t w_amax_row_stride, + // Columnwise (transposed) outputs (all three together, or all null). + uint8_t* __restrict__ q_col, fp8e4m3* __restrict__ s_dec_col, + float* __restrict__ w_amax_col, const size_t s_dec_col_stride, + const size_t w_amax_col_stride, const float* __restrict__ noop) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + + const int tile_n = static_cast(blockIdx.x); + const int tile_m = static_cast(blockIdx.y); + const int tid = static_cast(threadIdx.x); // 0..63 + + const int row_base = tile_m * 64; + const int col_base = tile_n * 64; + if (row_base >= static_cast(rows) || col_base >= static_cast(cols)) { + return; + } + + const bool do_row = (q_row != nullptr); + const bool do_col = (q_col != nullptr); + + // 64x64 fp32 staging buffer with +1 column padding to side-step bank + // conflicts on the columnwise transpose access pattern. + __shared__ float in_sm[64][65]; + + // Cooperative load: thread tid loads its assigned row (64 elements). + // Padding with zeros keeps the amax reductions correct on M-tail rows + // (the dispatcher already guarantees full 64-aligned tiles, so this is + // strictly defensive). + { + const int gr = row_base + tid; + if (gr < static_cast(rows)) { +#pragma unroll + for (int e = 0; e < 64; e++) { + const int gc = col_base + e; + in_sm[tid][e] = static_cast(in[static_cast(gr) * ld_row_elts + gc]); + } + } else { +#pragma unroll + for (int e = 0; e < 64; e++) { + in_sm[tid][e] = 0.f; + } + } + } + __syncthreads(); + + using IType2 = FPx2; + + // ============================ ROWWISE PASS ============================ + if (do_row && (row_base + tid) < static_cast(rows)) { + const int r = row_base + tid; + + float wmx = 0.f; +#pragma unroll + for (int e = 0; e < 64; e++) { + wmx = fmaxf(wmx, fabsf(in_sm[tid][e])); + } + const float S_enc = compute_global_encode_scaling_factor_FP4(fmaxf(wmx, 1e-12f)); + + w_amax_row[static_cast(r) * w_amax_row_stride + tile_n] = wmx; + + uint8_t* row_out = q_row + static_cast(r) * (cols / 2); +#pragma unroll + for (int b = 0; b < 4; b++) { + float bmx = 0.f; + float vals[16]; +#pragma unroll + for (int e = 0; e < 16; e++) { + const float v = in_sm[tid][b * 16 + e]; + vals[e] = v; + bmx = fmaxf(bmx, fabsf(v)); + } + const fp8e4m3 s_dec = compute_decoding_scaling_factor(bmx, S_enc); + const float s_dec_f = static_cast(s_dec); + // Match the reference's all-zero-block branch (see + // ``NVFP4Quantizer1x64Ref``): when ``bmx == 0`` ``s_dec`` saturates + // to 0 and a naive ``S_enc / 0`` would NaN through the cvt. + const float block_scale = (s_dec_f == 0.f) ? 
0.f : __fdiv_rn(S_enc, s_dec_f); + + const int sub_blk_global = (col_base + b * 16) / 16; + s_dec_row[static_cast(r) * s_dec_row_stride + sub_blk_global] = s_dec; + + const size_t byte_off = static_cast(col_base + b * 16) / 2; +#pragma unroll + for (int q = 0; q < 4; q++) { + const int e0 = q * 4; + IType2 in01{static_cast(vals[e0]), static_cast(vals[e0 + 1])}; + IType2 in23{static_cast(vals[e0 + 2]), static_cast(vals[e0 + 3])}; + fp4e2m1x4 qu{}; + ptx::mul_cvt_4x(qu, in01, in23, block_scale); + *reinterpret_cast(row_out + byte_off + static_cast(2 * q)) = qu; + } + } + } + + // No barrier between rowwise and colwise passes: ``in_sm`` is read-only + // after the initial load+sync, all writes go to disjoint global memory + // regions, so the two passes are safely concurrent at warp granularity. + + // =========================== COLUMNWISE PASS =========================== + if (do_col && (col_base + tid) < static_cast(cols)) { + const int c = col_base + tid; + + float wmx = 0.f; +#pragma unroll + for (int e = 0; e < 64; e++) { + wmx = fmaxf(wmx, fabsf(in_sm[e][tid])); + } + const float S_enc = compute_global_encode_scaling_factor_FP4(fmaxf(wmx, 1e-12f)); + + w_amax_col[static_cast(c) * w_amax_col_stride + tile_m] = wmx; + + // Transposed output: q_col is laid out as (N, M/2). Each "row" of the + // transposed tensor corresponds to one column of the input, so byte + // stride is rows/2 along original-M. + uint8_t* col_out = q_col + static_cast(c) * (rows / 2); +#pragma unroll + for (int b = 0; b < 4; b++) { + float bmx = 0.f; + float vals[16]; +#pragma unroll + for (int e = 0; e < 16; e++) { + const float v = in_sm[b * 16 + e][tid]; + vals[e] = v; + bmx = fmaxf(bmx, fabsf(v)); + } + const fp8e4m3 s_dec = compute_decoding_scaling_factor(bmx, S_enc); + const float s_dec_f = static_cast(s_dec); + const float block_scale = (s_dec_f == 0.f) ? 0.f : __fdiv_rn(S_enc, s_dec_f); + + const int sub_blk_global = (row_base + b * 16) / 16; + s_dec_col[static_cast(c) * s_dec_col_stride + sub_blk_global] = s_dec; + + const size_t byte_off = static_cast(row_base + b * 16) / 2; +#pragma unroll + for (int q = 0; q < 4; q++) { + const int e0 = q * 4; + IType2 in01{static_cast(vals[e0]), static_cast(vals[e0 + 1])}; + IType2 in23{static_cast(vals[e0 + 2]), static_cast(vals[e0 + 3])}; + fp4e2m1x4 qu{}; + ptx::mul_cvt_4x(qu, in01, in23, block_scale); + *reinterpret_cast(col_out + byte_off + static_cast(2 * q)) = qu; + } + } + } +#endif // __CUDA_ARCH__ >= 1000 +} + +#endif // FP4_TYPE_SUPPORTED + +} // namespace + +void quantize_1x64_local_encode(const Tensor& input, const Tensor& noop, Tensor* output, + const QuantizationConfig& /* quant_config */, cudaStream_t stream) { +#if FP4_TYPE_SUPPORTED + CheckNoopTensor(noop, "cast_noop"); + NVTE_CHECK(input.has_data(), "NVFP4 1x64: input has no data."); + NVTE_CHECK(output->has_data() || output->has_columnwise_data(), + "NVFP4 1x64: at least one of rowwise/columnwise output must be allocated."); + NVTE_CHECK(!output->with_gemm_swizzled_scales, + "NVFP4 1x64: expects compact (non-gemm) scales on both directions."); + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + if (rows == 0 || cols == 0) { + return; + } + // 1x16 sub-block inside a 1x64 K-window: both dimensions must be multiples + // of 64 to keep all four sub-blocks of every window in-bounds for both + // directions. The PyTorch wrapper (NVFP4Quantizer::create_tensor) sizes the + // ``amax`` slot accordingly, so this check also pins down the Python side. 
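+  // Illustrative sizing: a 256x1024 input launches a dim3(16, 4) grid of 64x64
+  // tiles and writes, rowwise, (256, 512) packed FP4 bytes, (256, 64) E4M3
+  // scales and a (256, 16) FP32 window amax; columnwise (transposed) the
+  // buffers are (1024, 128) bytes, (1024, 16) scales and (1024, 4) amax.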
+ NVTE_CHECK(cols % 64 == 0, "NVFP4 1x64: K (cols) must be a multiple of 64, got: ", cols); + NVTE_CHECK(rows % 64 == 0, "NVFP4 1x64: M (rows) must be a multiple of 64, got: ", rows); + + uint8_t* q_row = nullptr; + fp8e4m3* s_dec_row = nullptr; + float* w_amax_row = nullptr; + size_t s_dec_row_stride = 0; + size_t w_amax_row_stride = 0; + if (output->has_data()) { + NVTE_CHECK(output->scale_inv.dptr != nullptr, + "NVFP4 1x64: rowwise scale_inv must be allocated."); + NVTE_CHECK(output->amax.dptr != nullptr, + "NVFP4 1x64: rowwise amax (per-window) buffer is required."); + q_row = reinterpret_cast(output->data.dptr); + s_dec_row = reinterpret_cast(output->scale_inv.dptr); + w_amax_row = reinterpret_cast(output->amax.dptr); + s_dec_row_stride = output->scale_inv.shape.size() > 1 ? output->scale_inv.shape[1] : 1; + w_amax_row_stride = output->amax.shape.size() > 1 ? output->amax.shape[1] : 1; + NVTE_CHECK(s_dec_row_stride == cols / 16, + "NVFP4 1x64: rowwise scale_inv stride must equal cols/16, got ", s_dec_row_stride); + NVTE_CHECK(w_amax_row_stride == cols / 64, + "NVFP4 1x64: rowwise amax stride must equal cols/64, got ", w_amax_row_stride); + } + + uint8_t* q_col = nullptr; + fp8e4m3* s_dec_col = nullptr; + float* w_amax_col = nullptr; + size_t s_dec_col_stride = 0; + size_t w_amax_col_stride = 0; + if (output->has_columnwise_data()) { + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, + "NVFP4 1x64: columnwise scale_inv must be allocated."); + NVTE_CHECK(output->columnwise_amax.dptr != nullptr, + "NVFP4 1x64: columnwise amax (per-window) buffer is required."); + q_col = reinterpret_cast(output->columnwise_data.dptr); + s_dec_col = reinterpret_cast(output->columnwise_scale_inv.dptr); + w_amax_col = reinterpret_cast(output->columnwise_amax.dptr); + s_dec_col_stride = + output->columnwise_scale_inv.shape.size() > 1 ? output->columnwise_scale_inv.shape[1] : 1; + w_amax_col_stride = + output->columnwise_amax.shape.size() > 1 ? output->columnwise_amax.shape[1] : 1; + NVTE_CHECK(s_dec_col_stride == rows / 16, + "NVFP4 1x64: columnwise scale_inv stride must equal rows/16, got ", + s_dec_col_stride); + NVTE_CHECK(w_amax_col_stride == rows / 64, + "NVFP4 1x64: columnwise amax stride must equal rows/64, got ", w_amax_col_stride); + } + + const size_t n_win = cols / 64; + const size_t m_tiles = rows / 64; + dim3 grid(static_cast(n_win), static_cast(m_tiles), 1); + constexpr int kBlock = 64; + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(input.dtype(), IType, { + const IType* in_t = reinterpret_cast(input.data.dptr); + nvfp4_1x64_fused_per_tile<<>>( + in_t, rows, cols, static_cast(cols), q_row, s_dec_row, w_amax_row, s_dec_row_stride, + w_amax_row_stride, q_col, s_dec_col, w_amax_col, s_dec_col_stride, w_amax_col_stride, + reinterpret_cast(noop.data.dptr)); + NVTE_CHECK_CUDA(cudaGetLastError()); + }); + +#else + (void)input; + (void)noop; + (void)output; + (void)stream; + NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); +#endif +} + +} // namespace nvfp4 +} // namespace dispatch +} // namespace transformer_engine diff --git a/transformer_engine/common/cast/nvfp4/quantize_nvfp4_1x64.cuh b/transformer_engine/common/cast/nvfp4/quantize_nvfp4_1x64.cuh new file mode 100644 index 0000000000..ddc2a165b8 --- /dev/null +++ b/transformer_engine/common/cast/nvfp4/quantize_nvfp4_1x64.cuh @@ -0,0 +1,45 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file quantize_nvfp4_1x64.cuh + * \brief NVFP4 hierarchical 1x64 cast (rowwise + optional transposed columnwise), + * with per-1x64-K-window S_enc and per-1x16 sub-block E4M3 s_dec. + * + * The kernel produces, for an (M, N) input, any non-empty subset of: + * * rowwise NVFP4 data + 1x16 E4M3 scales + (M, N/64) FP32 window amax + * * columnwise (transposed) NVFP4 data + 1x16 E4M3 scales + (N, M/64) FP32 window amax + * + * The "window amax" tensors are stored on the existing ``amax`` / + * ``columnwise_amax`` slots of the C++ ``Tensor`` (their shape is upgraded + * from a scalar (1,) -- as used by the per-tensor NVFP4 path -- to a 2D + * per-window buffer in 1x64 mode). Consumers that need the global tensor + * amax can take ``max`` over the per-window buffer at trivial cost. + * + * Non-RHT, non-2D, non-stochastic-rounding only. Both M and N are + * required to be multiples of 64 by the dispatcher. + */ +#ifndef TRANSFORMER_ENGINE_QUANTIZE_NVFP4_1X64_CUH_ +#define TRANSFORMER_ENGINE_QUANTIZE_NVFP4_1X64_CUH_ + +#include + +#include "../../common.h" + +namespace transformer_engine { +namespace dispatch { +namespace nvfp4 { + +// Hierarchical 1x64 + 1x16 NVFP4 cast. Routes to the fused rowwise+columnwise +// kernel; populates whichever of ``data`` / ``columnwise_data`` are present +// on ``output`` (and the matching scales + window amax buffers). +void quantize_1x64_local_encode(const Tensor& input, const Tensor& noop, Tensor* output, + const QuantizationConfig& quant_config, cudaStream_t stream); + +} // namespace nvfp4 +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_QUANTIZE_NVFP4_1X64_CUH_ diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 6e207370dd..c2fe3079ad 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -470,6 +470,7 @@ struct QuantizationConfig { bool nvfp4_2d_quantization = false; bool stochastic_rounding = false; bool use_fast_math = false; + bool nvfp4_rowwise_1x64_local_encode = false; static constexpr size_t attr_sizes[] = { sizeof(uint8_t), // force_pow_2_scales @@ -479,7 +480,8 @@ struct QuantizationConfig { sizeof(NVTETensor), // rng_seed and offset sizeof(uint8_t), // nvfp4_2d_quantization sizeof(uint8_t), // stochastic_rounding - sizeof(uint8_t) // use_fast_math + sizeof(uint8_t), // use_fast_math + sizeof(uint8_t) // nvfp4_rowwise_1x64_local_encode }; }; diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index b7461a85d1..38856808df 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -370,6 +370,9 @@ enum NVTEQuantizationConfigAttribute { * inconsistently between kernels. */ kNVTEQuantizationConfigUseFastMath = 7, + /*! When true, NVFP4 non-RHT rowwise cast uses a per-1x64-K-tile amax to compute + * S_enc = (fp8_max*fp4_max)/tile_amax, instead of a single per-tensor global amax. */ + kNVTEQuantizationConfigNVFP4Rowwise1x64LocalEncode = 8, kNVTEQuantizationConfigNumAttributes }; @@ -1296,6 +1299,13 @@ class QuantizationConfigWrapper { sizeof(val)); } + /*! \brief When true, use per-1x64 (along K) local amax for NVFP4 S_enc in rowwise non-RHT cast. 
*/ + void set_nvfp4_rowwise_1x64_local_encode(bool v) { + const auto val = static_cast(v); + nvte_set_quantization_config_attribute( + config_, kNVTEQuantizationConfigNVFP4Rowwise1x64LocalEncode, &val, sizeof(val)); + } + private: /*! \brief Wrapped NVTEQuantizationConfig. */ NVTEQuantizationConfig config_ = nullptr; diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index b97504f2ae..07956e0fd4 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -1030,6 +1030,9 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config, case kNVTEQuantizationConfigUseFastMath: bool_to_uint8(config_.use_fast_math, buf); break; + case kNVTEQuantizationConfigNVFP4Rowwise1x64LocalEncode: + bool_to_uint8(config_.nvfp4_rowwise_1x64_local_encode, buf); + break; default: NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); } @@ -1085,6 +1088,9 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config, case kNVTEQuantizationConfigUseFastMath: uint8_to_bool(buf, config_.use_fast_math); break; + case kNVTEQuantizationConfigNVFP4Rowwise1x64LocalEncode: + uint8_to_bool(buf, config_.nvfp4_rowwise_1x64_local_encode); + break; default: NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); } diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index b689a1c1b4..9f45fd7822 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -18,6 +18,7 @@ #include "../extensions.h" #include "common.h" #include "common/util/system.h" +#include "nvfp4_1x64.h" #include "pybind.h" #include "transformer_engine/transformer_engine.h" @@ -1193,10 +1194,8 @@ void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, const size_t num_tensors = input_list.size(); const auto &quantizer = *quantizers.front(); - std::vector nvte_tensor_input_list; std::vector nvte_tensor_output_list; for (size_t i = 0; i < num_tensors; ++i) { - nvte_tensor_input_list.push_back(input_list[i].data()); nvte_tensor_output_list.push_back(output_list[i].data()); } @@ -1210,8 +1209,8 @@ void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, quant_config_list.emplace_back(QuantizationConfigWrapper()); } - // TODO: this is only true because the non-RHT path doesn't have grouped kernels yet, which we can be optimized - // so that we can generate all rng states at once + // Per-tensor quantize: each tensor gets its own kernel launch, so RNG states + // are advanced once per tensor on the host. bool with_bulk_generate_rng_states = false; bool need_stochastic_rounding = quantizer.stochastic_rounding; @@ -1224,25 +1223,39 @@ void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, need_separate_rng_states, quant_config_list, dummy_quant_config_list_colwise); // colwise rng states are not needed in this case - // We need: - // 1. Rowwise amax = amax for input - // 2. 
Columnwise amax = amax for input too - // Columnwise amax will be filled with a fused D2D copy from rowwise amax - // Note that the multi compute amax API expects rowwise amax pointer to be not null - // So we need to set the pointer accordingly to make colwise-only quantization work - std::vector orig_amax_ptr_list; - for (size_t i = 0; i < num_tensors; i++) { - auto rowwise_amax_ptr = output_list[i].get_amax().data_ptr; - orig_amax_ptr_list.push_back(rowwise_amax_ptr); - auto columnwise_amax_ptr = output_list[i].get_columnwise_amax().data_ptr; - void *amax_ptr = rowwise_amax_ptr != nullptr ? rowwise_amax_ptr : columnwise_amax_ptr; - NVTE_CHECK(amax_ptr != nullptr, "Could not find amax pointer"); - output_list[i].set_amax(amax_ptr, DType::kFloat32, std::vector{1}); + // Optional NVFP4 1x64 local-encode: per-window amax is computed inside the + // kernel (see quantize.cuh), so the global per-tensor amax pre-pass is + // skipped. Otherwise the per-tensor nvte_quantize_v2 path remains identical + // to the non-1x64 baseline. + const bool use_rowwise_1x64 = nvfp4_1x64::local_encode_from_env(); + for (size_t i = 0; i < num_tensors; ++i) { + nvfp4_1x64::config_apply(quant_config_list[i], quantizer.with_2d_quantization, + quantizer.stochastic_rounding, use_rowwise_1x64); } - nvte_group_amax(input.data(), reinterpret_cast(nvte_tensor_output_list.data()), - split_sections.data(), num_tensors, stream); - for (size_t i = 0; i < num_tensors; i++) { - output_list[i].set_amax(orig_amax_ptr_list[i], DType::kFloat32, std::vector{1}); + nvfp4_1x64::require_ok_for_split(quantizer.rowwise_usage, quantizer.columnwise_usage, + quantizer.stochastic_rounding); + + if (!use_rowwise_1x64) { + // We need: + // 1. Rowwise amax = amax for input + // 2. Columnwise amax = amax for input too + // Columnwise amax will be filled with a fused D2D copy from rowwise amax + // Note that the multi compute amax API expects rowwise amax pointer to be not null + // So we need to set the pointer accordingly to make colwise-only quantization work + std::vector orig_amax_ptr_list; + for (size_t i = 0; i < num_tensors; i++) { + auto rowwise_amax_ptr = output_list[i].get_amax().data_ptr; + orig_amax_ptr_list.push_back(rowwise_amax_ptr); + auto columnwise_amax_ptr = output_list[i].get_columnwise_amax().data_ptr; + void *amax_ptr = rowwise_amax_ptr != nullptr ? rowwise_amax_ptr : columnwise_amax_ptr; + NVTE_CHECK(amax_ptr != nullptr, "Could not find amax pointer"); + output_list[i].set_amax(amax_ptr, DType::kFloat32, std::vector{1}); + } + nvte_group_amax(input.data(), reinterpret_cast(nvte_tensor_output_list.data()), + split_sections.data(), num_tensors, stream); + for (size_t i = 0; i < num_tensors; i++) { + output_list[i].set_amax(orig_amax_ptr_list[i], DType::kFloat32, std::vector{1}); + } } // Quantize tensors individually @@ -1391,8 +1404,24 @@ std::vector split_quantize(const at::Tensor &tensor, [](const py::handle &quantizer) -> bool { return detail::IsNVFP4Quantizers(quantizer.ptr()); })) { - allocation_method = AllocationMethod::BULK_NVFP4; - quantization_method = QuantizationMethod::FUSED_NVFP4; + // bulk_allocate_nvfp4_tensors only knows the stock (1,) amax layout; the + // 1x64 kernel needs a per-window amax buffer of shape (M_chunk, N/64). + // When 1x64 is requested, fall back to per-tensor allocation + // (NVFP4Quantizer::create_tensor is 1x64-aware) but keep the fused + // quantize path so split_quantize_nvfp4_impl_helper still drives the + // kernel. 
Mirror the BULK_NVFP4 case's `% 128 != 0` fallback so + // non-aligned inner dims still go through the unfused per-tensor path + // (split_quantize_nvfp4_impl asserts inner dim is a multiple of 128). + if (nvfp4_1x64::local_encode_from_env()) { + if (!input_shape.empty() && input_shape.back() % 128 != 0) { + quantization_method = QuantizationMethod::UNFUSED; + } else { + quantization_method = QuantizationMethod::FUSED_NVFP4; + } + } else { + allocation_method = AllocationMethod::BULK_NVFP4; + quantization_method = QuantizationMethod::FUSED_NVFP4; + } } } diff --git a/transformer_engine/pytorch/csrc/nvfp4_1x64.h b/transformer_engine/pytorch/csrc/nvfp4_1x64.h new file mode 100644 index 0000000000..e61e4db784 --- /dev/null +++ b/transformer_engine/pytorch/csrc/nvfp4_1x64.h @@ -0,0 +1,62 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file nvfp4_1x64.h + * \brief Small helpers for NVFP4 hierarchical 1x64 cast (env + config + + * preconditions), shared between NVFP4Quantizer and split_quantize to + * avoid duplicating policy. The "rowwise" in the env-var name is + * historical -- the kernel now produces rowwise and/or columnwise + * (transposed) output in a single fused pass. + */ +#ifndef TRANSFORMER_ENGINE_PYTORCH_NVFP4_1X64_H_ +#define TRANSFORMER_ENGINE_PYTORCH_NVFP4_1X64_H_ + +#include + +#include "common/util/logging.h" +#include "common/util/system.h" + +namespace transformer_engine::pytorch::nvfp4_1x64 { + +/// Whether the hierarchical 1x64 cast is requested. The env-var name retains +/// its original ROWWISE_ prefix for backward compatibility with users of the +/// rowwise-only kernel that shipped first; both directions are now supported. +[[nodiscard]] inline bool local_encode_from_env() { + return transformer_engine::getenv("NVTE_NVFP4_ROWWISE_1X64_LOCAL_ENCODE", false); +} + +/// Apply 2D mode, SR, and optional 1x64 flag to a quantization config. +inline void config_apply(QuantizationConfigWrapper& cfg, bool nvfp4_2d, bool stochastic_rounding, + bool use_1x64) { + cfg.set_nvfp4_2d_quantization(nvfp4_2d); + cfg.set_stochastic_rounding(stochastic_rounding); + cfg.set_nvfp4_rowwise_1x64_local_encode(use_1x64); +} + +/// Preconditions for \p NVFP4Quantizer::quantize_impl (non-split). +inline void require_ok_for_non_split(bool with_rht, bool /* columnwise */, bool sr) { + if (!local_encode_from_env()) { + return; + } + NVTE_CHECK( + !with_rht, + "NVTE_NVFP4_ROWWISE_1X64_LOCAL_ENCODE=1 requires non-RHT (e.g. NVTE_NVFP4_DISABLE_RHT=1)."); + NVTE_CHECK(!sr, + "NVTE_NVFP4_ROWWISE_1X64_LOCAL_ENCODE=1 is incompatible with stochastic rounding."); +} + +/// Preconditions for \p split_quantize (non-RHT path). 
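+/// The rowwise/columnwise arguments are currently unused (the fused kernel
+/// serves either direction); only stochastic rounding is rejected here, since
+/// the caller is the non-RHT split path and RHT is already excluded there.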
+inline void require_ok_for_split(bool /* want_rowwise */, bool /* have_columnwise */, bool sr) { + if (!local_encode_from_env()) { + return; + } + NVTE_CHECK(!sr, + "NVTE_NVFP4_ROWWISE_1X64_LOCAL_ENCODE in split_quantize is incompatible with SR."); +} + +} // namespace transformer_engine::pytorch::nvfp4_1x64 + +#endif // TRANSFORMER_ENGINE_PYTORCH_NVFP4_1X64_H_ diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index b59f3fa3c5..4ba23fee94 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -8,6 +8,7 @@ #include "common.h" #include "common/util/system.h" +#include "nvfp4_1x64.h" #include "pybind.h" #include "torch/torch.h" @@ -1752,6 +1753,24 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve at::Tensor columnwise_data_tensor, columnwise_scale_inv_tensor, amax_columnwise; const auto bit8_tensor_opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); const auto bit32_tensor_opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + + // In hierarchical 1x64 mode the ``amax`` slot is repurposed from a global + // ``(1,)`` scalar to a per-window FP32 buffer: ``(M, N/64)`` for rowwise + // and ``(N, M/64)`` for the transposed columnwise output. The kernel + // requires ``flat_first_dim % 64 == 0`` and ``flat_last_dim % 64 == 0``, + // which the dispatcher re-checks; we duplicate it here so allocation + // failures surface before we ever launch the kernel. + const bool use_1x64 = nvfp4_1x64::local_encode_from_env(); + if (use_1x64) { + NVTE_CHECK(flat_first_dim % 64 == 0, + "NVFP4 1x64 local encode requires the leading flattened dim to be a multiple of " + "64, got ", + flat_first_dim); + NVTE_CHECK(flat_last_dim % 64 == 0, + "NVFP4 1x64 local encode requires the last dim to be a multiple of 64, got ", + flat_last_dim); + } + if (rowwise_usage) { const std::vector scale_inv_shape_int64(rowwise_scale_inv_shape.begin(), rowwise_scale_inv_shape.end()); @@ -1759,7 +1778,13 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve rowwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts); // hadamard amax kernel will zero out pointer with ZeroAmaxKernel // nvte_compute_amax_with_config will zero out the pointer if needed - amax_rowwise = at::empty({1}, bit32_tensor_opts); + if (use_1x64) { + amax_rowwise = at::empty( + {static_cast(flat_first_dim), static_cast(flat_last_dim) / 64}, + bit32_tensor_opts); + } else { + amax_rowwise = at::empty({1}, bit32_tensor_opts); + } } if (columnwise_usage) { const std::vector scale_inv_shape_int64(columnwise_scale_inv_shape.begin(), @@ -1774,7 +1799,13 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve columnwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts); // hadamard amax kernel will zero out pointer with ZeroAmaxKernel // nvte_compute_amax_with_config will zero out the pointer if needed - amax_columnwise = at::empty({1}, bit32_tensor_opts); + if (use_1x64) { + amax_columnwise = at::empty( + {static_cast(flat_last_dim), static_cast(flat_first_dim) / 64}, + bit32_tensor_opts); + } else { + amax_columnwise = at::empty({1}, bit32_tensor_opts); + } } // Convert tensors to Python @@ -1847,7 +1878,9 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve out_cpp.set_rowwise_data(rowwise_data_tensor.data_ptr(), DType::kFloat4E2M1, shape); out_cpp.set_rowwise_scale_inv(rowwise_scale_inv_tensor.data_ptr(), DType::kFloat8E4M3, rowwise_scale_inv_shape); - 
out_cpp.set_amax(amax_rowwise.data_ptr(), DType::kFloat32, std::vector{1}); + const std::vector amax_row_shape = + use_1x64 ? std::vector{flat_first_dim, flat_last_dim / 64} : std::vector{1}; + out_cpp.set_amax(amax_rowwise.data_ptr(), DType::kFloat32, amax_row_shape); } if (columnwise_usage) { // enforce 2D shape to avoid [S, B, H] shape and B and be 1 @@ -1858,8 +1891,9 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve col_data_shape_fp4); out_cpp.set_columnwise_scale_inv(columnwise_scale_inv_tensor.data_ptr(), DType::kFloat8E4M3, columnwise_scale_inv_shape); - out_cpp.set_columnwise_amax(amax_columnwise.data_ptr(), DType::kFloat32, - std::vector{1}); + const std::vector amax_col_shape = + use_1x64 ? std::vector{flat_last_dim, flat_first_dim / 64} : std::vector{1}; + out_cpp.set_columnwise_amax(amax_columnwise.data_ptr(), DType::kFloat32, amax_col_shape); } out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); this->set_quantization_params(&out_cpp); @@ -2227,8 +2261,10 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou quant_config.set_noop_tensor(noop_flag->data()); quant_config_columnwise.set_noop_tensor(noop_flag->data()); } - quant_config.set_nvfp4_2d_quantization(this->with_2d_quantization); - quant_config.set_stochastic_rounding(this->stochastic_rounding); + nvfp4_1x64::config_apply(quant_config, this->with_2d_quantization, this->stochastic_rounding, + nvfp4_1x64::local_encode_from_env()); + nvfp4_1x64::require_ok_for_non_split(this->with_rht, this->columnwise_usage, + this->stochastic_rounding); // We only need RHT for columnwise usage. // flat first dim and last dim for multi dimensional input @@ -2304,7 +2340,12 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou "Use with_post_rht_amax=true instead."); } } else { // Without RHT - if (compute_amax) { + // The hierarchical 1x64 cast computes per-window amax inside the kernel + // (writing ``(M, N/64)`` / ``(N, M/64)`` buffers into the ``amax`` slots); + // running ``nvte_compute_amax_with_config`` here would overwrite the + // per-window allocation with a single FP32 global amax and clobber the + // shape back to ``(1,)``. Skip the precompute for that path entirely. + if (compute_amax && !nvfp4_1x64::local_encode_from_env()) { // Amax pointers auto rowwise_amax_ptr = out.get_amax().data_ptr; auto columnwise_amax_ptr = out.get_columnwise_amax().data_ptr; diff --git a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_1x64.py b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_1x64.py new file mode 100644 index 0000000000..1a554c62f0 --- /dev/null +++ b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_1x64.py @@ -0,0 +1,233 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Reference implementation for NVFP4 hierarchical 1x64 cast (rowwise + columnwise). + +The hierarchical 1x64 + 1x16 scheme replaces the per-tensor encoding scaling +factor used by stock NVFP4 with a per-1x64-K-window scaling factor; the four +1x16 sub-blocks inside a window share their parent ``S_enc``. The CUDA kernel +that implements this lives in +``transformer_engine/common/cast/nvfp4/quantize_nvfp4_1x64.cu`` and is +dispatched when ``NVTE_NVFP4_ROWWISE_1X64_LOCAL_ENCODE=1``. + +This file mirrors that kernel's arithmetic in pure PyTorch so tests can +compare the kernel's output byte-for-byte against a Python oracle. 
The +arithmetic ordering and intermediate clamps are chosen to match what the +kernel does: + +* ``S_enc_tile = (FP8_MAX*FP4_MAX) / max(tile_amax, 1e-12)`` clamped to + ``fp32_max``; +* ``s_dec = saturating_cast(vec_max * S_enc_tile / FP4_MAX)`` (the + ``1/FP4_MAX`` is folded into the ``S_enc`` multiplier to match + ``compute_decoding_scaling_factor``); +* ``block_scale = S_enc_tile / fp32(s_dec)`` with an explicit ``s_dec == 0`` + short-circuit to 0 (matches the kernel's ``s_dec_f == 0.f`` branch); +* ``q = round_fp4_satfinite(x_fp32 * block_scale)`` with values clamped to + ``[-FP4_MAX, FP4_MAX]`` before packing. + +For the columnwise (transposed) output the kernel runs the same math along +the original M direction with a 64x1 window; the reference implements this +by simply running the rowwise routine on ``x.T``. The window amax tensor +is exposed alongside ``data`` / ``scale`` so consumers can reconstruct +``S_enc_window`` (the per-block ``s_dec`` alone is not enough information +to dequantize correctly). + +Only the non-RHT, non-2D, non-stochastic-rounding path is supported. +""" + +from __future__ import annotations + +import dataclasses +from typing import Optional, Tuple + +import torch + +from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import cast_to_fp4x2 + +# Window/block geometry is fixed by the kernel design. +WINDOW_K: int = 64 +BLOCK_K: int = 16 +BLOCKS_PER_WINDOW: int = WINDOW_K // BLOCK_K # 4 + +# E2M1 / E4M3 numeric extrema (matches ``TypeExtrema`` in core_nvfp4.cuh). +FLOAT4_E2M1_MAX: float = 6.0 +FLOAT8_E4M3_MAX: float = 448.0 + +# Matches the kernel's ``fmaxf(tile_amax, 1e-12f)`` clamp guarding the divisor +# of ``compute_global_encode_scaling_factor_FP4``. +_TILE_AMAX_FLOOR: float = 1e-12 + + +@dataclasses.dataclass +class RefNVFP4Tensor1x64: + """Container for the hierarchical 1x64 reference output. + + Mirrors the subset of attributes that the bit-exact test inspects. + + Attributes + ---------- + data: + Packed rowwise FP4 bytes, ``(M, N // 2)`` ``uint8``. + scale: + Per-1x16-block rowwise decode scale (E4M3), ``(M, N // 16)`` + ``float8_e4m3fn``. + window_amax_row: + Per-1x64-window rowwise amax, ``(M, N // 64)`` ``float32``. + ``S_enc_window`` is recoverable from this via + ``compute_global_encode_scaling_factor_FP4(window_amax)``; consumers + need it for correct dequantization (the per-block ``s_dec`` alone + does not contain enough information). + columnwise_data, columnwise_scale, window_amax_col: + Their columnwise (transposed) counterparts. The transposed FP4 data + has shape ``(N, M // 2)`` (matching the production cast+transpose + layout), the columnwise scales ``(N, M // 16)``, and the columnwise + window amax ``(N, M // 64)``. ``None`` if columnwise was not + requested. + """ + + data: Optional[torch.Tensor] = None + scale: Optional[torch.Tensor] = None + window_amax_row: Optional[torch.Tensor] = None + columnwise_data: Optional[torch.Tensor] = None + columnwise_scale: Optional[torch.Tensor] = None + window_amax_col: Optional[torch.Tensor] = None + + +class NVFP4Quantizer1x64Ref: + """Reference implementation of the hierarchical 1x64 cast kernel. + + Constructor takes flags matching the kernel's two output-direction + switches; RHT, 2D scaling, and stochastic rounding are not exposed + because the kernel rejects them at dispatch time and we do not want + the reference to drift away from the kernel's actual capabilities. 
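+
+    Typical usage (illustrative; ``x`` is any 2D tensor whose dims are
+    multiples of 64, as in the bit-exact tests)::
+
+        ref = NVFP4Quantizer1x64Ref(rowwise=True, columnwise=True).quantize(x)
+        qx, sx, amax = ref.data, ref.scale, ref.window_amax_row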
+ """ + + def __init__(self, rowwise: bool = True, columnwise: bool = False) -> None: + if not rowwise and not columnwise: + raise ValueError("At least one of rowwise / columnwise must be True.") + self.rowwise = rowwise + self.columnwise = columnwise + + @staticmethod + def _quantize_2d( + x: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Run the 1x64 reference math on a 2D input along its trailing dim. + + Returns ``(qx, sx, window_amax)`` where ``qx`` is ``(M, N // 2)`` + ``uint8``, ``sx`` is ``(M, N // BLOCK_K)`` ``float8_e4m3fn``, and + ``window_amax`` is ``(M, N // WINDOW_K)`` ``float32``. + + The columnwise pass is implemented by calling this routine on + ``x.T.contiguous()``; both passes therefore share a single + arithmetic path, which is the cleanest way to keep the two + directions consistent with the kernel (the kernel itself shares + its math by re-using ``compute_decoding_scaling_factor`` / + ``compute_global_encode_scaling_factor_FP4`` between passes). + """ + if x.ndim != 2: + raise ValueError(f"NVFP4Quantizer1x64Ref expects a 2D tensor, got {x.ndim}D") + M, N = x.shape + if N % WINDOW_K != 0: + raise ValueError( + f"N={N} must be a multiple of WINDOW_K={WINDOW_K} (kernel hard requirement)" + ) + + device = x.device + fp32_max = torch.tensor(torch.finfo(torch.float32).max, device=device, dtype=torch.float32) + fp4_max = torch.tensor(FLOAT4_E2M1_MAX, device=device, dtype=torch.float32) + fp8_max = torch.tensor(FLOAT8_E4M3_MAX, device=device, dtype=torch.float32) + + Np = N + n_win = Np // WINDOW_K + n_blk = Np // BLOCK_K + + x_fp32 = x.to(torch.float32).contiguous() + x_win = x_fp32.view(M, n_win, WINDOW_K) + x_blk = x_fp32.view(M, n_blk, BLOCK_K) + + # 1x64 tile amax. The kernel applies ``fmaxf(tile_amax, 1e-12f)`` to + # the divisor; do the same here. ``S_enc_tile`` is then computed + # exactly like ``compute_global_encode_scaling_factor_FP4``: divide, + # clamp to fp32_max, and fall back to 1.0 if the divisor or quotient + # is zero (the latter branch is dead given the floor but is mirrored + # for parity with the C++ helper). + tile_amax = torch.amax(torch.abs(x_win), dim=-1, keepdim=True) # (M, n_win, 1) + tile_amax_safe = torch.clamp(tile_amax, min=_TILE_AMAX_FLOOR) + + S_enc_tile = (fp8_max * fp4_max) / tile_amax_safe + S_enc_tile = torch.minimum(S_enc_tile, fp32_max) + S_enc_tile = torch.where( + (tile_amax_safe == 0) | (S_enc_tile == 0), + torch.ones_like(S_enc_tile), + S_enc_tile, + ) + + # Fold ``1 / fp4_max`` into the multiplier the same way the kernel + # does in ``compute_decoding_scaling_factor`` (``S_enc * fp4_max_inv``). + # Keeping the operation order identical is what makes the resulting + # E4M3 scale bit-exact with the kernel. + S_enc_tile_mul_inv6 = S_enc_tile * torch.reciprocal(fp4_max) + + # 1x16 block amax and per-block S_enc broadcast. Each 1x64 window + # spans BLOCKS_PER_WINDOW (=4) consecutive 1x16 sub-blocks, so + # ``repeat_interleave`` along the block axis aligns one S_enc_tile + # to every block inside that window. + vec_max = torch.amax(torch.abs(x_blk), dim=-1, keepdim=True) # (M, n_blk, 1) + S_enc_per_blk = S_enc_tile.repeat_interleave(BLOCKS_PER_WINDOW, dim=1) + S_enc_per_blk_mul = S_enc_tile_mul_inv6.repeat_interleave(BLOCKS_PER_WINDOW, dim=1) + + # decode_scale = saturating_cast(vec_max * S_enc_tile / 6). + # The kernel does not clamp before the cast; we do, because PyTorch's + # ``.to(float8_e4m3fn)`` does not match CUDA's saturating cast for + # values above FP8_MAX. 
After the explicit clamp the two paths agree. + decode_scale_fp32 = vec_max * S_enc_per_blk_mul + decode_scale_fp32 = torch.minimum(decode_scale_fp32, fp32_max) + decode_scale_fp32 = torch.clamp(decode_scale_fp32, min=-fp8_max, max=fp8_max) + decode_scale_e4m3 = decode_scale_fp32.to(torch.float8_e4m3fn) + decode_scale_back_fp32 = decode_scale_e4m3.to(torch.float32) + + # block_scale = S_enc_tile / s_dec, matching ``__fdiv_rn`` in the + # kernel. All-zero blocks have ``s_dec == 0``, which would yield + # ``+inf`` here and propagate NaN through the downstream multiply + # (``cvt.rn.satfinite.e2m1x4.f32(NaN)`` saturates to +FP4_MAX on + # SM10 -- we do NOT want that). Short-circuit to 0 to mirror the + # kernel's ``s_dec_f == 0.f`` branch. + zero_blk = decode_scale_back_fp32 == 0 + denom = torch.where( + zero_blk, torch.ones_like(decode_scale_back_fp32), decode_scale_back_fp32 + ) + encode_scale = S_enc_per_blk / denom + encode_scale = torch.where(zero_blk, torch.zeros_like(encode_scale), encode_scale) + encode_scale = torch.minimum(encode_scale, fp32_max) + + # Apply scale, clamp to FP4 range, and pack two FP4 values per byte. + scaled_x = x_blk * encode_scale + clipped_x = torch.clamp(scaled_x, -fp4_max, fp4_max).reshape(M, Np) + qx = cast_to_fp4x2(clipped_x).contiguous() # (M, N // 2) + + sx = decode_scale_e4m3.squeeze(-1).contiguous() # (M, n_blk) + + # Per-1x64-window amax, exposed for consumer-side dequantization. + window_amax = tile_amax.squeeze(-1).to(torch.float32).contiguous() # (M, n_win) + + return qx, sx, window_amax + + def quantize(self, tensor: torch.Tensor) -> RefNVFP4Tensor1x64: + """Quantize ``tensor`` and return a ``RefNVFP4Tensor1x64``.""" + out = RefNVFP4Tensor1x64() + if self.rowwise: + qx, sx, win_amax = self._quantize_2d(tensor) + out.data = qx + out.scale = sx + out.window_amax_row = win_amax + if self.columnwise: + # The columnwise output is the rowwise quantization of the + # transpose; both directions share the same math and the same + # ``s_dec``/``block_scale`` chain. + qx_t, sx_t, win_amax_t = self._quantize_2d(tensor.transpose(0, 1).contiguous()) + out.columnwise_data = qx_t + out.columnwise_scale = sx_t + out.window_amax_col = win_amax_t + return out
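+
+
+# Illustrative smoke check (not part of the unit tests): quantize a random
+# 64-aligned tensor on CUDA and print the output shapes documented on
+# ``RefNVFP4Tensor1x64``. The (256, 1024) size is an arbitrary example.
+if __name__ == "__main__":
+    _x = torch.randn(256, 1024, dtype=torch.bfloat16, device="cuda")
+    _ref = NVFP4Quantizer1x64Ref(rowwise=True, columnwise=True).quantize(_x)
+    print("rowwise  data :", tuple(_ref.data.shape))              # (256, 512)
+    print("rowwise  scale:", tuple(_ref.scale.shape))             # (256, 64)
+    print("rowwise  amax :", tuple(_ref.window_amax_row.shape))   # (256, 16)
+    print("colwise  data :", tuple(_ref.columnwise_data.shape))   # (1024, 128)
+    print("colwise  scale:", tuple(_ref.columnwise_scale.shape))  # (1024, 16)
+    print("colwise  amax :", tuple(_ref.window_amax_col.shape))   # (1024, 4)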