From e16712ac7a64e81577f3e4bc6a8356861e915cc1 Mon Sep 17 00:00:00 2001 From: Cael Ling Date: Sat, 25 Apr 2026 03:47:10 -0700 Subject: [PATCH 01/10] Add NVFP4 rowwise 1x64 local-encode quantization Signed-off-by: Cael Ling --- reference_hierarchical_nvfp4/__init__.py | 36 +++ .../compare_64x64_global_vs_1x64.py | 113 +++++++ .../fp8_e4m3_utils_np.py | 73 +++++ .../hierarchical_nvfp4_ref.py | 305 ++++++++++++++++++ .../hierarchical_nvfp4_ref_numpy.py | 285 ++++++++++++++++ transformer_engine/common/CMakeLists.txt | 1 + .../common/cast/dispatch/quantize.cuh | 17 +- .../nvfp4/group_quantize_transpose_nvfp4.cuh | 167 +++++----- .../cast/nvfp4/quantize_nvfp4_1x64_rowwise.cu | 161 +++++++++ .../nvfp4/quantize_nvfp4_1x64_rowwise.cuh | 31 ++ transformer_engine/common/common.h | 4 +- .../transformer_engine/transformer_engine.h | 10 + .../common/transformer_engine.cpp | 6 + .../pytorch/csrc/extensions/cast.cpp | 122 ++++--- transformer_engine/pytorch/csrc/nvfp4_1x64.h | 64 ++++ transformer_engine/pytorch/csrc/quantizer.cpp | 6 +- 16 files changed, 1280 insertions(+), 121 deletions(-) create mode 100644 reference_hierarchical_nvfp4/__init__.py create mode 100644 reference_hierarchical_nvfp4/compare_64x64_global_vs_1x64.py create mode 100644 reference_hierarchical_nvfp4/fp8_e4m3_utils_np.py create mode 100644 reference_hierarchical_nvfp4/hierarchical_nvfp4_ref.py create mode 100644 reference_hierarchical_nvfp4/hierarchical_nvfp4_ref_numpy.py create mode 100644 transformer_engine/common/cast/nvfp4/quantize_nvfp4_1x64_rowwise.cu create mode 100644 transformer_engine/common/cast/nvfp4/quantize_nvfp4_1x64_rowwise.cuh create mode 100644 transformer_engine/pytorch/csrc/nvfp4_1x64.h diff --git a/reference_hierarchical_nvfp4/__init__.py b/reference_hierarchical_nvfp4/__init__.py new file mode 100644 index 0000000000..a0b3192cb8 --- /dev/null +++ b/reference_hierarchical_nvfp4/__init__.py @@ -0,0 +1,36 @@ +"""NVFP4 reference: same *recipe* as TE (S_enc, S_dec fp8, bsi, FP4) except +``S_enc = (fp8_max*fp4_max)/amax`` uses each **1x64** window's max instead of +per-tensor global amax. See ``fp8_e4m3_utils_np`` and ``core_nvfp4.cuh``. + +- PyTorch: symbols below (requires ``torch`` + ``numpy`` for E4M3). +- CPU: ``hierarchical_nvfp4_ref_numpy`` (``numpy`` only), or run + ``python reference_hierarchical_nvfp4/hierarchical_nvfp4_ref_numpy.py``. +""" + +from .hierarchical_nvfp4_ref import ( + COARSE, + FINE, + HierarchicalNVFP4Colwise, + HierarchicalNVFP4Rowwise, + dequantize_colwise, + dequantize_rowwise, + fp4_e2m1_grid_torch, + quantize_columnwise_1x64_1x16, + quantize_rowwise_1x64_1x16, + reference_matmul_tn, + roundtrip_error, +) + +__all__ = [ + "COARSE", + "FINE", + "HierarchicalNVFP4Colwise", + "HierarchicalNVFP4Rowwise", + "dequantize_colwise", + "dequantize_rowwise", + "fp4_e2m1_grid_torch", + "quantize_columnwise_1x64_1x16", + "quantize_rowwise_1x64_1x16", + "reference_matmul_tn", + "roundtrip_error", +] diff --git a/reference_hierarchical_nvfp4/compare_64x64_global_vs_1x64.py b/reference_hierarchical_nvfp4/compare_64x64_global_vs_1x64.py new file mode 100644 index 0000000000..78c6054504 --- /dev/null +++ b/reference_hierarchical_nvfp4/compare_64x64_global_vs_1x64.py @@ -0,0 +1,113 @@ +# Compare: TE-style NVFP4 (single S_enc from *global* amax) vs +# reference 1x64 (S_enc from max|x| in each 1x64 K-window = per-row for K=64). +# Matrix shape (M, K) = (64, 64), rowwise along K, 1x16 blocks * 4 per row. +# +# Run (no torch): python3 reference_hierarchical_nvfp4/compare_64x64_global_vs_1x64.py + +from __future__ import annotations + +import os +import sys + +import numpy as np + +# Load sibling ref modules (avoid package __init__ -> torch) +_REF_DIR = os.path.dirname(os.path.abspath(__file__)) +if _REF_DIR not in sys.path: + sys.path.insert(0, _REF_DIR) + +from fp8_e4m3_utils_np import ( + TINY, + compute_S_dec_f32_before_cast_te, + compute_S_enc_from_amax_1x64_like_te, + e4m3_u8_to_f32, + f32_to_e4m3_u8, +) +from hierarchical_nvfp4_ref_numpy import ( + FINE, + FP4_E2M1_GRID, + dequantize_rowwise, + quantize_rowwise_1x64_1x16, +) + +M = 64 +K = 64 + + +def _round_nearest_fp4(x: np.ndarray) -> np.ndarray: + d = np.abs(x[..., None] - FP4_E2M1_GRID) + return FP4_E2M1_GRID[np.argmin(d, axis=-1).astype(np.int64)] + + +def te_nvfp4_rowwise_global_senc( + x: np.ndarray, eps: float = 1e-12 +) -> tuple[np.ndarray, float, np.ndarray, np.ndarray]: + """ + Same math as ref but S_enc = 2688 / global_amax (single float for all blocks). + Returns: x_recon, global_amax, w_pre_fp4, S_dec_u8 (M, n16) per-block fp8 + """ + x = np.asarray(x, np.float32) + m, k = x.shape + g_amax = float(np.max(np.abs(x))) + S_g = compute_S_enc_from_amax_1x64_like_te(g_amax) + n16 = (k + FINE - 1) // FINE + w = np.empty_like(x) + s_dec = np.empty((m, n16), dtype=np.uint8) + for r in range(m): + t16b = 0 + while t16b * FINE < k: + lo, hi = t16b * FINE, min((t16b + 1) * FINE, k) + segx = x[r, lo:hi] + bamax = float(np.max(np.abs(segx))) + if bamax < eps: + bamax = float(eps) + raw = compute_S_dec_f32_before_cast_te(bamax, S_g) + u = f32_to_e4m3_u8(np.array([raw], dtype=np.float32).reshape(1)) + s_dec[r, t16b] = u.reshape(-1)[0] + s_d = max(float(e4m3_u8_to_f32(s_dec[r : r + 1, t16b : t16b + 1]).reshape(-1)[0]), TINY) + bsi = S_g / s_d + w[r, lo:hi] = segx * bsi + t16b += 1 + q = _round_nearest_fp4(w) + t16g = (np.arange(k) // FINE).astype(np.int64) + sde = e4m3_u8_to_f32(s_dec[:, t16g].astype(np.uint8)) + sde = np.maximum(sde, TINY) + x_recon = q * (sde / S_g) + return x_recon.astype(np.float32), g_amax, w, s_dec + + +def main() -> None: + rng = np.random.default_rng(2026) + # "Real" data: not uniform — heavy-tailed + one row scaled up for local/global gap + x = rng.standard_normal((M, K)).astype(np.float32) + x *= 0.35 + x[7, :] *= 4.0 + x[0:8, 12:20] += 0.4 + + x_ref1 = quantize_rowwise_1x64_1x16(x) + recon_1x64 = dequantize_rowwise(x_ref1) + recon_global, g_amax, _, _ = te_nvfp4_rowwise_global_senc(x) + + d = np.abs(recon_1x64 - x) + d2 = np.abs(recon_global - x) + dg = np.abs(recon_1x64 - recon_global) + + print("=== 64x64 数值对比 (rowwise, K=4×16) ===") + print("global_amax =", g_amax) + print("S_enc 现网(全局) = 2688 / global_amax =", compute_S_enc_from_amax_1x64_like_te(g_amax)) + print("---") + print("量化再反归一 vs 原张量: max abs err [1x64 S_enc 参考] :", float(np.max(d))) + print("量化再反归一 vs 原张量: max abs err [全局 S_enc] :", float(np.max(d2))) + print("RMS 误差 vs 原张量 [1x64]:", float(np.sqrt(np.mean(d**2)))) + print("RMS 误差 vs 原张量 [全局]:", float(np.sqrt(np.mean(d2**2)))) + print("---") + print("两种重建之间的 max abs 差 |recon_1x64 - recon_global| :", float(np.max(dg))) + print("RMS( recon_1x64 - recon_global ) :", float(np.sqrt(np.mean(dg**2)))) + fn = float(np.linalg.norm(x, "fro")) + if fn > 0: + print("||x||_F =", fn) + print("Fro 相对: ||recon_1x64 - recon_global||_F / ||x||_F =", float(np.linalg.norm(dg, "fro") / fn)) + + +if __name__ == "__main__": + main() diff --git a/reference_hierarchical_nvfp4/fp8_e4m3_utils_np.py b/reference_hierarchical_nvfp4/fp8_e4m3_utils_np.py new file mode 100644 index 0000000000..46b8f87cfa --- /dev/null +++ b/reference_hierarchical_nvfp4/fp8_e4m3_utils_np.py @@ -0,0 +1,73 @@ +# FP8 E4M3: decode uint8->float32 and encode f32->nearest uint8 (all 256 codes). +# Matches OCP/FP8 E4M3 for reference roundtrip; close to CUDA static_cast(f32). + +from __future__ import annotations + +import numpy as np + +# Match transformer_engine::detail::TypeExtrema (common.h) / nvfp4.cu kernels +FP8_E4M3_FMAX: float = 448.0 +FP4_E2M1_FMAX: float = 6.0 +S_ENC_NUMER: float = FP8_E4M3_FMAX * FP4_E2M1_FMAX # 2688 +FLT_MAX: float = 3.402823466e38 +TINY: float = 1.17549435e-38 + + +def _decode_e4m3_byte(b: int) -> float: + u = b & 0xFF + sign = u >> 7 + exp = (u >> 3) & 0x0F + man = u & 0x07 + if exp == 0: + if man == 0: + return 0.0 + v = (man / 8.0) * (2.0 ** (-6)) + else: + v = (1.0 + man / 8.0) * (2.0 ** (exp - 7)) + return -v if sign else v + + +# Precompute 256 float32 values for all E4M3 codes +_E4M3_TABLE: np.ndarray = np.array( + [_decode_e4m3_byte(i) for i in range(256)], dtype=np.float32 +) + + +def f32_to_e4m3_u8(x: np.ndarray) -> np.ndarray: + """Round each element to nearest fp8e4m3 (by L_inf on 256 codes). x can be any shape.""" + x = np.asarray(x, dtype=np.float32) + flat = x.ravel()[:, None] + d = np.abs(flat - _E4M3_TABLE[None, :]) + out = np.argmin(d, axis=1).astype(np.uint8) + return out.reshape(x.shape) + + +def e4m3_u8_to_f32(b: np.ndarray) -> np.ndarray: + return _E4M3_TABLE[np.asarray(b, dtype=np.int32) & 0xFF].astype(np.float32) + + +def compute_S_enc_from_amax_1x64_like_te(amax: float, eps: float = TINY) -> float: + """ + Same as compute_global_encode_scaling_factor_FP4 in core_nvfp4.cuh, but *amax* is + the 1x64 local max (replaces per-tensor global amax in this ref). + """ + a = float(amax) + if a <= 0.0 or not np.isfinite(a): + return 1.0 + safe = max(a, eps) + g = S_ENC_NUMER / safe + return float(min(g, FLT_MAX)) + + +def compute_S_dec_f32_before_cast_te(block_amax: float, S_enc: float) -> float: + """ + Unquantized S_dec = block_amax * (S_enc / 6) before cast to e4m3 + (compute_decoding_scaling_factor, quantization_SF in core_nvfp4.cuh). + """ + return float(np.float32(block_amax * (S_enc * (1.0 / FP4_E2M1_FMAX)))) + + +def f32_e4m3_f32(x: float) -> float: + """Cast pipeline f32 -> fp8e4m3 -> f32, numpy.""" + u = f32_to_e4m3_u8(np.array([x], dtype=np.float32)) + return float(e4m3_u8_to_f32(u)[0]) diff --git a/reference_hierarchical_nvfp4/hierarchical_nvfp4_ref.py b/reference_hierarchical_nvfp4/hierarchical_nvfp4_ref.py new file mode 100644 index 0000000000..3b92918752 --- /dev/null +++ b/reference_hierarchical_nvfp4/hierarchical_nvfp4_ref.py @@ -0,0 +1,305 @@ +# PyTorch twin of `hierarchical_nvfp4_ref_numpy.py`. +# +# Aligned with TE NVFP4 (core_nvfp4.cuh, quantize_nvfp4.cuh) for: +# S_enc, S_dec = cast_f32(block_amax * S_enc / 6) as fp8e4m3, block_scale_inv = S_enc / f32(S_dec), +# then FP4 from x * bsi. **S_enc** uses **1x64 local amax** instead of per-tensor global amax. +# +# Standalone. CPU twin: hierarchical_nvfp4_ref_numpy.py + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Tuple + +import torch + +import numpy as np + +from .fp8_e4m3_utils_np import TINY, compute_S_dec_f32_before_cast_te +from .fp8_e4m3_utils_np import compute_S_enc_from_amax_1x64_like_te as _s_enc_from_amax_cpu + +try: + from .fp8_e4m3_utils_np import e4m3_u8_to_f32 as _e4m3_u8_f32_np + from .fp8_e4m3_utils_np import f32_to_e4m3_u8 as _f32_e4m3_u8_np +except ImportError: + import os + import sys + + sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + from fp8_e4m3_utils_np import e4m3_u8_to_f32 as _e4m3_u8_f32_np + from fp8_e4m3_utils_np import f32_to_e4m3_u8 as _f32_e4m3_u8_np + +COARSE = 64 +FINE = 16 + + +def fp4_e2m1_grid_torch(device, dtype) -> torch.Tensor: + vals = ( + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, + ) + return torch.tensor(vals, device=device, dtype=dtype) + + +def _round_to_nearest_fp4(x: torch.Tensor, grid: torch.Tensor) -> torch.Tensor: + diff = (x.unsqueeze(-1) - grid).abs() + return grid[diff.argmin(dim=-1)] + + +def _pack_fp4_along_k(x_fp4: torch.Tensor) -> torch.Tensor: + m, k = x_fp4.shape + grid = fp4_e2m1_grid_torch(x_fp4.device, torch.float32) + diff = (x_fp4.unsqueeze(-1) - grid).abs() + nibble = diff.argmin(dim=-1).to(torch.uint8) + n_pairs = (k + 1) // 2 + out = torch.zeros(m, n_pairs, dtype=torch.uint8, device=x_fp4.device) + for p in range(k // 2): + j = 2 * p + out[:, p] = (nibble[:, j] & 0x0F) | ((nibble[:, j + 1] & 0x0F) << 4) + if k % 2 == 1: + out[:, -1] = nibble[:, -1] & 0x0F + return out + + +def _unpack_fp4_along_k(data_u8: torch.Tensor, k: int) -> torch.Tensor: + m = data_u8.size(0) + grid = fp4_e2m1_grid_torch(data_u8.device, torch.float32) + out = torch.empty(m, k, device=data_u8.device, dtype=grid.dtype) + p = 0 + for j in range(0, k - 1, 2): + b = data_u8[:, p] + p += 1 + out[:, j] = grid[(b & 0x0F).to(torch.long)] + out[:, j + 1] = grid[((b >> 4) & 0x0F).to(torch.long)] + if k % 2 == 1: + b = data_u8[:, p] + out[:, -1] = grid[(b & 0x0F).to(torch.long)] + return out + + +def _pack_fp4_along_m(x_fp4: torch.Tensor) -> torch.Tensor: + m, k = x_fp4.shape + grid = fp4_e2m1_grid_torch(x_fp4.device, torch.float32) + diff = (x_fp4.unsqueeze(-1) - grid).abs() + nibble = diff.argmin(dim=-1).to(torch.uint8) + n_pairs = (m + 1) // 2 + out = torch.zeros(n_pairs, k, dtype=torch.uint8, device=x_fp4.device) + for p in range(m // 2): + r = 2 * p + out[p, :] = (nibble[r, :] & 0x0F) | ((nibble[r + 1, :] & 0x0F) << 4) + if m % 2 == 1: + out[-1, :] = nibble[-1, :] & 0x0F + return out + + +def _unpack_fp4_along_m(data_u8: torch.Tensor, m: int, k: int) -> torch.Tensor: + grid = fp4_e2m1_grid_torch(data_u8.device, torch.float32) + out = torch.empty(m, k, device=data_u8.device, dtype=grid.dtype) + p = 0 + for r in range(0, m - 1, 2): + b = data_u8[p, :] + p += 1 + out[r, :] = grid[(b & 0x0F).to(torch.long)] + out[r + 1, :] = grid[((b >> 4) & 0x0F).to(torch.long)] + if m % 2 == 1: + b = data_u8[p, :] + out[-1, :] = grid[(b & 0x0F).to(torch.long)] + return out + + +@dataclass +class HierarchicalNVFP4Rowwise: + m: int + k: int + data_u8: torch.Tensor + S_enc: torch.Tensor + S_dec_u8: torch.Tensor + amax_64: torch.Tensor + + +@dataclass +class HierarchicalNVFP4Colwise: + m: int + k: int + data_u8: torch.Tensor + S_enc: torch.Tensor + S_dec_u8: torch.Tensor + amax_64: torch.Tensor + + +def _amax_64_k(x: torch.Tensor) -> torch.Tensor: + m, k = int(x.size(0)), int(x.size(1)) + n64 = (k + COARSE - 1) // COARSE + a = x.new_empty(m, n64) + for t64 in range(n64): + lo, hi = t64 * COARSE, min((t64 + 1) * COARSE, k) + a[:, t64] = x[:, lo:hi].abs().max(dim=1).values + return a + + +def quantize_rowwise_1x64_1x16(x: torch.Tensor, eps: float = 1e-12) -> HierarchicalNVFP4Rowwise: + assert x.dim() == 2 + m, k = int(x.size(0)), int(x.size(1)) + device, dtype = x.device, x.dtype + x = x.to(torch.float32) + n16 = (k + FINE - 1) // FINE + amax_64 = _amax_64_k(x) + n64 = amax_64.size(1) + S_enc = torch.empty(m, n64, device=device, dtype=torch.float32) + for ri in range(m): + for t64 in range(n64): + S_enc[ri, t64] = _s_enc_from_amax_cpu(float(amax_64[ri, t64].item())) + S_dec_u8 = torch.empty(m, n16, device=device, dtype=torch.uint8) + w = x.clone() + grid = fp4_e2m1_grid_torch(device, torch.float32) + for row in range(m): + t16b = 0 + while t16b * FINE < k: + lo, hi = t16b * FINE, min((t16b + 1) * FINE, k) + t64 = lo // COARSE + s_e = S_enc[row, t64] + segx = x[row, lo:hi] + bamax = float(segx.abs().max().item()) + if bamax < eps: + bamax = float(eps) + raw = compute_S_dec_f32_before_cast_te(bamax, float(s_e.item())) + u = int( + _f32_e4m3_u8_np(np.array([raw], dtype=np.float32).reshape(1))[ + 0 + ] + ) + S_dec_u8[row, t16b] = u + s_dec_f = max(float(_e4m3_u8_f32_np(np.array([u], dtype=np.uint8))[0]), TINY) + bsi = float(s_e.item()) / s_dec_f + w[row, lo:hi] = segx * bsi + t16b += 1 + q = _round_to_nearest_fp4(w, grid) + return HierarchicalNVFP4Rowwise(m, k, _pack_fp4_along_k(q), S_enc, S_dec_u8, amax_64) + + +def dequantize_rowwise(p: HierarchicalNVFP4Rowwise) -> torch.Tensor: + m, k, device = p.m, p.k, p.data_u8.device + q = _unpack_fp4_along_k(p.data_u8, k) + j16 = (torch.arange(k, device=device) // FINE).long() + j64 = (torch.arange(k, device=device) // COARSE).long() + sdec = torch.from_numpy( + _e4m3_u8_f32_np(p.S_dec_u8[:, j16].cpu().numpy().astype(np.uint8)) + ).to(device=device, dtype=torch.float32) + senc = p.S_enc[:, j64] + sdec = torch.clamp(sdec, min=TINY) + return (q * (sdec / senc)).to(torch.float32) + + +def _amax_64_m(x: torch.Tensor) -> torch.Tensor: + m, k = int(x.size(0)), int(x.size(1)) + n64 = (m + COARSE - 1) // COARSE + a = x.new_empty(n64, k) + for t64 in range(n64): + lo, hi = t64 * COARSE, min((t64 + 1) * COARSE, m) + a[t64, :] = x[lo:hi, :].abs().max(dim=0).values + return a + + +def quantize_columnwise_1x64_1x16( + x: torch.Tensor, eps: float = 1e-12 +) -> HierarchicalNVFP4Colwise: + assert x.dim() == 2 + m, k = int(x.size(0)), int(x.size(1)) + device = x.device + x = x.to(torch.float32) + n16 = (m + FINE - 1) // FINE + amax_64 = _amax_64_m(x) + n64 = amax_64.size(0) + S_enc = torch.empty(n64, k, device=device, dtype=torch.float32) + for t64 in range(n64): + for col in range(k): + S_enc[t64, col] = _s_enc_from_amax_cpu(float(amax_64[t64, col].item())) + S_dec_u8 = torch.empty(n16, k, device=device, dtype=torch.uint8) + w = x.clone() + grid = fp4_e2m1_grid_torch(device, torch.float32) + for col in range(k): + t16b = 0 + while t16b * FINE < m: + lo, hi = t16b * FINE, min((t16b + 1) * FINE, m) + t64 = lo // COARSE + s_e = S_enc[t64, col] + segx = x[lo:hi, col] + bamax = float(segx.abs().max().item()) + if bamax < eps: + bamax = float(eps) + raw = compute_S_dec_f32_before_cast_te(bamax, float(s_e.item())) + u = int( + _f32_e4m3_u8_np(np.array([raw], dtype=np.float32).reshape(1))[ + 0 + ] + ) + S_dec_u8[t16b, col] = u + s_dec_f = max(float(_e4m3_u8_f32_np(np.array([u], dtype=np.uint8))[0]), TINY) + bsi = float(s_e.item()) / s_dec_f + w[lo:hi, col] = segx * bsi + t16b += 1 + q = _round_to_nearest_fp4(w, grid) + return HierarchicalNVFP4Colwise(m, k, _pack_fp4_along_m(q), S_enc, S_dec_u8, amax_64) + + +def dequantize_colwise(p: HierarchicalNVFP4Colwise) -> torch.Tensor: + m, k, device = p.m, p.k, p.data_u8.device + q = _unpack_fp4_along_m(p.data_u8, m, k) + r16 = (torch.arange(m, device=device) // FINE).long() + r64 = (torch.arange(m, device=device) // COARSE).long() + sdec = torch.from_numpy( + _e4m3_u8_f32_np(p.S_dec_u8[r16, :].cpu().numpy().astype(np.uint8)) + ).to(device=device, dtype=torch.float32) + senc = p.S_enc[r64, :] + sdec = torch.clamp(sdec, min=TINY) + return (q * (sdec / senc)).to(torch.float32) + + +def reference_matmul_tn( + a_rows: HierarchicalNVFP4Rowwise, + b_cols: HierarchicalNVFP4Colwise, +) -> torch.Tensor: + return dequantize_rowwise(a_rows) @ dequantize_colwise(b_cols).T + + +def roundtrip_error( + x: torch.Tensor, mode: str +) -> Tuple[torch.Tensor, torch.Tensor]: + if mode == "rowwise": + p = quantize_rowwise_1x64_1x16(x) + y = dequantize_rowwise(p) + elif mode == "colwise": + p = quantize_columnwise_1x64_1x16(x) + y = dequantize_colwise(p) + else: + raise ValueError("mode is rowwise or colwise") + e = (x.to(torch.float32) - y).abs().max() + return y, e + + +if __name__ == "__main__": + torch.manual_seed(0) + m, n, kdim = 4, 5, 128 + a = torch.randn(m, kdim) + b = torch.randn(n, kdim) + pa, pb = quantize_rowwise_1x64_1x16(a), quantize_columnwise_1x64_1x16(b) + y = reference_matmul_tn(pa, pb) + y_ref = a @ b.T + err = (y - y_ref).abs().max() + print("matmul ref max abs err (via dequant):", float(err)) + _, e = roundtrip_error(a, "rowwise") + print("rowwise roundtrip max abs err:", float(e)) diff --git a/reference_hierarchical_nvfp4/hierarchical_nvfp4_ref_numpy.py b/reference_hierarchical_nvfp4/hierarchical_nvfp4_ref_numpy.py new file mode 100644 index 0000000000..9d3dc43830 --- /dev/null +++ b/reference_hierarchical_nvfp4/hierarchical_nvfp4_ref_numpy.py @@ -0,0 +1,285 @@ +# CPU reference aligned with TE NVFP4 (core_nvfp4.cuh + quantize_nvfp4.cuh path) except: +# S_enc = (fp8_max * fp4_max) / amax uses **1x64 local amax** instead of per-tensor global amax. +# +# Stages: S_enc(1x64) -> S_dec = cast_fp8( block_amax_1x16 * S_enc / 6 ) -> bsi = S_enc / f32(S_dec) +# -> t = x * bsi -> FP4 E2M1 (nearest). Dequant: x_hat = q_fp4 * (f32(S_dec) / S_enc). +# +# Requires: numpy. Run: python TransformerEngine/reference_hierarchical_nvfp4/hierarchical_nvfp4_ref_numpy.py + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Tuple + +import numpy as np + +try: + from .fp8_e4m3_utils_np import ( + TINY, + compute_S_dec_f32_before_cast_te, + compute_S_enc_from_amax_1x64_like_te, + e4m3_u8_to_f32, + f32_to_e4m3_u8, + ) +except ImportError: + import os + import sys + + sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) + from fp8_e4m3_utils_np import ( + TINY, + compute_S_dec_f32_before_cast_te, + compute_S_enc_from_amax_1x64_like_te, + e4m3_u8_to_f32, + f32_to_e4m3_u8, + ) + +COARSE = 64 +FINE = 16 + +FP4_E2M1_GRID = np.array( + [ + 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, + -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, + ], + dtype=np.float32, +) + + +def _round_to_nearest_fp4(x: np.ndarray) -> np.ndarray: + d = np.abs(x[..., None] - FP4_E2M1_GRID) + return FP4_E2M1_GRID[np.argmin(d, axis=-1).astype(np.int64)] + + +def _pack_fp4_along_k(x_fp4: np.ndarray) -> np.ndarray: + m, k = x_fp4.shape + d = np.abs(x_fp4[..., None] - FP4_E2M1_GRID) + nibble = np.argmin(d, axis=-1).astype(np.uint8) + n_pairs = (k + 1) // 2 + out = np.zeros((m, n_pairs), dtype=np.uint8) + for p in range(k // 2): + j = 2 * p + out[:, p] = (nibble[:, j] & 0x0F) | ((nibble[:, j + 1] & 0x0F) << 4) + if k % 2 == 1: + out[:, -1] = nibble[:, -1] & 0x0F + return out + + +def _unpack_fp4_along_k(data_u8: np.ndarray, k: int) -> np.ndarray: + m = data_u8.shape[0] + out = np.empty((m, k), dtype=np.float32) + p = 0 + for j in range(0, k - 1, 2): + b = data_u8[:, p] + p += 1 + out[:, j] = FP4_E2M1_GRID[(b & 0x0F).astype(np.int64)] + out[:, j + 1] = FP4_E2M1_GRID[((b >> 4) & 0x0F).astype(np.int64)] + if k % 2 == 1: + b = data_u8[:, p] + out[:, -1] = FP4_E2M1_GRID[(b & 0x0F).astype(np.int64)] + return out + + +def _pack_fp4_along_m(x_fp4: np.ndarray) -> np.ndarray: + m, k = x_fp4.shape + d = np.abs(x_fp4[..., None] - FP4_E2M1_GRID) + nibble = np.argmin(d, axis=-1).astype(np.uint8) + n_pairs = (m + 1) // 2 + out = np.zeros((n_pairs, k), dtype=np.uint8) + for p in range(m // 2): + r = 2 * p + out[p, :] = (nibble[r, :] & 0x0F) | ((nibble[r + 1, :] & 0x0F) << 4) + if m % 2 == 1: + out[-1, :] = nibble[-1, :] & 0x0F + return out + + +def _unpack_fp4_along_m(data_u8: np.ndarray, m: int, k: int) -> np.ndarray: + out = np.empty((m, k), dtype=np.float32) + p = 0 + for r in range(0, m - 1, 2): + b = data_u8[p, :] + p += 1 + out[r, :] = FP4_E2M1_GRID[(b & 0x0F).astype(np.int64)] + out[r + 1, :] = FP4_E2M1_GRID[((b >> 4) & 0x0F).astype(np.int64)] + if m % 2 == 1: + b = data_u8[p, :] + out[-1, :] = FP4_E2M1_GRID[(b & 0x0F).astype(np.int64)] + return out + + +@dataclass +class HierarchicalNVFP4RowwiseNp: + m: int + k: int + data_u8: np.ndarray + S_enc: np.ndarray + S_dec_u8: np.ndarray + amax_64: np.ndarray + + +@dataclass +class HierarchicalNVFP4ColwiseNp: + m: int + k: int + data_u8: np.ndarray + S_enc: np.ndarray + S_dec_u8: np.ndarray + amax_64: np.ndarray + + +def _amax_64_k(x: np.ndarray, m: int, k: int) -> np.ndarray: + n64 = (k + COARSE - 1) // COARSE + out = np.empty((m, n64), dtype=np.float32) + for row in range(m): + for t64 in range(n64): + lo, hi = t64 * COARSE, min((t64 + 1) * COARSE, k) + seg = x[row, lo:hi] + out[row, t64] = float(np.max(np.abs(seg))) if seg.size else 0.0 + return out + + +def quantize_rowwise_1x64_1x16(x: np.ndarray, eps: float = 1e-12) -> HierarchicalNVFP4RowwiseNp: + x = np.asarray(x, dtype=np.float32) + assert x.ndim == 2 + m, k = int(x.shape[0]), int(x.shape[1]) + n16 = (k + FINE - 1) // FINE + amax_64 = _amax_64_k(x, m, k) + n64 = amax_64.shape[1] + S_enc = np.empty((m, n64), dtype=np.float32) + for ri in range(m): + for t64 in range(n64): + S_enc[ri, t64] = compute_S_enc_from_amax_1x64_like_te(float(amax_64[ri, t64])) + S_dec_u8 = np.empty((m, n16), dtype=np.uint8) + w = np.empty_like(x) + for row in range(m): + t16b = 0 + while t16b * FINE < k: + lo, hi = t16b * FINE, min((t16b + 1) * FINE, k) + t64 = lo // COARSE + S = float(S_enc[row, t64]) + segx = x[row, lo:hi] + bamax = float(np.max(np.abs(segx))) + if bamax < eps: + bamax = float(eps) + raw = compute_S_dec_f32_before_cast_te(bamax, S) + u = f32_to_e4m3_u8(np.array(raw, dtype=np.float32).reshape(1)) + S_dec_u8[row, t16b] = u.ravel()[0] + s_dec_f = max(float(e4m3_u8_to_f32(S_dec_u8[row, t16b : t16b + 1])[0]), TINY) + bsi = S / s_dec_f + w[row, lo:hi] = segx * bsi + t16b += 1 + q = _round_to_nearest_fp4(w) + return HierarchicalNVFP4RowwiseNp( + m, k, _pack_fp4_along_k(q), S_enc, S_dec_u8, amax_64 + ) + + +def dequantize_rowwise(p: HierarchicalNVFP4RowwiseNp) -> np.ndarray: + m, k = p.m, p.k + q = _unpack_fp4_along_k(p.data_u8, k) + t16 = (np.arange(k) // FINE).astype(np.int64) + t64 = (np.arange(k) // COARSE).astype(np.int64) + sdec = e4m3_u8_to_f32(p.S_dec_u8[:, t16].astype(np.uint8)) + senc = p.S_enc[:, t64] + sdec = np.maximum(sdec, TINY) + return (q * (sdec / senc)).astype(np.float32) + + +def _amax_64_m(x: np.ndarray, m: int, k: int) -> np.ndarray: + n64 = (m + COARSE - 1) // COARSE + out = np.empty((n64, k), dtype=np.float32) + for col in range(k): + for t64 in range(n64): + lo, hi = t64 * COARSE, min((t64 + 1) * COARSE, m) + seg = x[lo:hi, col] + out[t64, col] = float(np.max(np.abs(seg))) if seg.size else 0.0 + return out + + +def quantize_columnwise_1x64_1x16( + x: np.ndarray, eps: float = 1e-12 +) -> HierarchicalNVFP4ColwiseNp: + x = np.asarray(x, dtype=np.float32) + assert x.ndim == 2 + m, k = int(x.shape[0]), int(x.shape[1]) + n16 = (m + FINE - 1) // FINE + amax_64 = _amax_64_m(x, m, k) + n64 = amax_64.shape[0] + S_enc = np.empty((n64, k), dtype=np.float32) + for t64 in range(n64): + for col in range(k): + S_enc[t64, col] = compute_S_enc_from_amax_1x64_like_te(float(amax_64[t64, col])) + S_dec_u8 = np.empty((n16, k), dtype=np.uint8) + w = np.empty_like(x) + for col in range(k): + t16b = 0 + while t16b * FINE < m: + lo, hi = t16b * FINE, min((t16b + 1) * FINE, m) + t64 = lo // COARSE + S = float(S_enc[t64, col]) + segx = x[lo:hi, col] + bamax = float(np.max(np.abs(segx))) + if bamax < eps: + bamax = float(eps) + raw = compute_S_dec_f32_before_cast_te(bamax, S) + u = f32_to_e4m3_u8(np.array(raw, dtype=np.float32).reshape(1)) + S_dec_u8[t16b, col] = u.ravel()[0] + s_dec_f = max(float(e4m3_u8_to_f32(S_dec_u8[t16b : t16b + 1, col : col + 1])[0, 0]), TINY) + bsi = S / s_dec_f + w[lo:hi, col] = segx * bsi + t16b += 1 + q = _round_to_nearest_fp4(w) + return HierarchicalNVFP4ColwiseNp( + m, k, _pack_fp4_along_m(q), S_enc, S_dec_u8, amax_64 + ) + + +def dequantize_colwise(p: HierarchicalNVFP4ColwiseNp) -> np.ndarray: + m, k = p.m, p.k + q = _unpack_fp4_along_m(p.data_u8, m, k) + r16 = (np.arange(m) // FINE).astype(np.int64) + r64 = (np.arange(m) // COARSE).astype(np.int64) + sdec = e4m3_u8_to_f32(p.S_dec_u8[r16, :].astype(np.uint8)) + senc = p.S_enc[r64, :] + sdec = np.maximum(sdec, TINY) + return (q * (sdec / senc)).astype(np.float32) + + +def reference_matmul_tn( + a_rows: HierarchicalNVFP4RowwiseNp, + b_cols: HierarchicalNVFP4ColwiseNp, +) -> np.ndarray: + return dequantize_rowwise(a_rows) @ dequantize_colwise(b_cols).T + + +def roundtrip_error(x: np.ndarray, mode: str) -> Tuple[np.ndarray, float]: + if mode == "rowwise": + p = quantize_rowwise_1x64_1x16(x) + y = dequantize_rowwise(p) + elif mode == "colwise": + p = quantize_columnwise_1x64_1x16(x) + y = dequantize_colwise(p) + else: + raise ValueError("mode is rowwise or colwise") + e = float(np.max(np.abs(x.astype(np.float32) - y))) + return y, e + + +def _self_test() -> None: + rng = np.random.default_rng(0) + m, n, kdim = 4, 5, 128 + a = rng.standard_normal((m, kdim)).astype(np.float32) + b = rng.standard_normal((n, kdim)).astype(np.float32) + pa, pb = quantize_rowwise_1x64_1x16(a), quantize_columnwise_1x64_1x16(b) + y = reference_matmul_tn(pa, pb) + a @ b.T + # loose bound (aggressive low precision) + assert np.isfinite(y).all() + assert dequantize_rowwise(pa).shape == a.shape + assert dequantize_colwise(pb).shape == b.shape + print("hierarchical_nvfp4_ref_numpy (TE1x64 path): _self_test OK") + + +if __name__ == "__main__": + _self_test() diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 3f684adbb4..bf17d5556f 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -219,6 +219,7 @@ list(APPEND transformer_engine_cuda_arch_specific_sources multi_tensor/compute_scale.cu recipe/mxfp8_scaling.cu recipe/nvfp4.cu + cast/nvfp4/quantize_nvfp4_1x64_rowwise.cu transpose/quantize_transpose_square_blockwise.cu transpose/quantize_transpose_vector_blockwise_fp4.cu) diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index 8d985f64f3..dccf5a2091 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -22,6 +22,7 @@ #include "../mxfp8/quantize_mxfp8.cuh" #include "../nvfp4/group_quantize_transpose_nvfp4.cuh" #include "../nvfp4/quantize_nvfp4.cuh" +#include "../nvfp4/quantize_nvfp4_1x64_rowwise.cuh" #include "../nvfp4/quantize_transpose_nvfp4.cuh" namespace transformer_engine { @@ -104,8 +105,17 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && (cols % 32 == 0) && output_tensor->has_data(); - // Launch NVFP4 quantize kernel - if (use_optimized_kernel) { + // Per-1x64-K-tile S_enc (non-RHT rowwise only). Fused window kernel; no columnwise / no GEMM. + if (quant_config_cpp.nvfp4_rowwise_1x64_local_encode) { + NVTE_CHECK(!output_tensor->has_columnwise_data(), + "NVFP4 rowwise 1x64 local encode does not support columnwise (transposed) output."); + NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, + "NVFP4 rowwise 1x64 local encode is incompatible with 2D block scaling."); + NVTE_CHECK(!quant_config_cpp.stochastic_rounding, + "NVFP4 rowwise 1x64 local encode does not support stochastic rounding yet."); + nvfp4::quantize_rowwise_1x64_local_encode(*input_tensor, *noop_tensor, output_tensor, + quant_config_cpp, stream); + } else if (use_optimized_kernel) { if (quant_config_cpp.nvfp4_2d_quantization) { nvfp4::quantize_transpose( *input_tensor, noop_tensor, output_tensor, &quant_config_cpp, stream); @@ -236,6 +246,9 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens CheckInputTensor(*grad_tensor, "input"); CheckOutputTensor(*output_tensor, "output", false); + NVTE_CHECK(!quant_config_cpp.nvfp4_rowwise_1x64_local_encode, + "NVFP4 rowwise 1x64 local encode is not implemented for backward quantization yet."); + // Choose kernel int32_t rows = grad_tensor->flat_first_dim(); int32_t cols = grad_tensor->flat_last_dim(); diff --git a/transformer_engine/common/cast/nvfp4/group_quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/group_quantize_transpose_nvfp4.cuh index a2f3dac15a..c827f466ed 100644 --- a/transformer_engine/common/cast/nvfp4/group_quantize_transpose_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/group_quantize_transpose_nvfp4.cuh @@ -48,6 +48,8 @@ struct MultiAmaxCastTransposeFusionArgs { void *output_colwise_scale_inv_list[kMaxTensorsPerKernel]; // (Unused for rowwise only scaling) output scale stride for colwise scaling int output_colwise_scale_stride[kMaxTensorsPerKernel]; + // (Unused for rowwise only scaling) output data stride for colwise data (in fp4e2m1x2 units) + int output_colwise_data_stride[kMaxTensorsPerKernel]; // Prefix sum (with leading zero) of split_sections of each tensor of input int split_sections_range[kMaxTensorsPerKernel + 1]; // Number of tensors (splits) being processed by kernel @@ -88,7 +90,7 @@ __device__ __forceinline__ void UpdateEncodeDecodeScaleFP32(float *amax_ptr, flo float *s_dec_ptr) { float s_env_value = (amax_ptr == nullptr) ? 1.0f : compute_global_encode_scaling_factor_FP4(*amax_ptr); - float s_dec_value = 1.0 / s_env_value; + float s_dec_value = 1.0f / s_env_value; *s_enc_ptr = s_env_value; *s_dec_ptr = s_dec_value; return; @@ -202,23 +204,15 @@ __global__ void __launch_bounds__(THREADS_NUM) const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y; const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X; - // TODO(zhongbo): add back when transpose is supported - // const size_t block_offset_Y_t = blockIdx.x * CHUNK_DIM_X; - // const size_t block_offset_X_t = blockIdx.y * CHUNK_DIM_Y; - const size_t chunk_rows = rows - block_offset_Y; const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; const size_t scales_block_offset_X_rowwise = blockIdx.x * SCALES_PER_CHUNK_X; - // TODO(zhongbo): add back when transpose is supported - // const size_t scales_block_offset_Y_t = blockIdx.x * CHUNK_DIM_X; - // const size_t scales_block_offset_X_t = blockIdx.y * SCALES_PER_CHUNK_Y; const size_t tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; const size_t tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; const size_t tid_X_colwise = threadIdx.x; const size_t tid_Y_t = tid_X_colwise; - // const size_t tid_X_t = 0; const size_t thread_offset_Y_rowwise = tid_Y_rowwise; const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM; @@ -232,17 +226,11 @@ __global__ void __launch_bounds__(THREADS_NUM) const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; - // TODO(zhongbo): add back when transpose is supported - // const size_t scales_offset_Y_t = scales_block_offset_Y_t + tid_Y_t; - // const size_t scales_offset_X_t = scales_block_offset_X_t; const size_t SFs_per_row = cols / SCALE_DIM; const bool rowwise_scale_is_within_bounds_X = scales_offset_X_rowwise < SFs_per_row; - // TODO(zhongbo): add back when transpose is supported - // const bool colwise_scale_is_within_bounds_Y = scales_offset_Y_t < cols; - // Helps resolving bank conflicts in shmem const int thread_lane = threadIdx.x % THREADS_PER_WARP; const int bank_group = thread_lane / THREADS_PER_BANK; @@ -283,12 +271,14 @@ __global__ void __launch_bounds__(THREADS_NUM) const bool is_master_thread = (threadIdx.x == 0); - // TODO (zhongbo): finish this float *amax_rowwise_ptr = nullptr; float *amax_colwise_ptr = nullptr; nvfp4_scale_t *split_rowwise_scale_ptr = nullptr; + fp4e2m1x2 *split_colwise_data_ptr = nullptr; + nvfp4_scale_t *split_colwise_scale_ptr = nullptr; + int split_colwise_data_stride = 0; + int split_colwise_scale_stride = 0; - // suppose the amax is fixed for the current 128x128 tile (need 128 padding) bool need_update_tensor_id = true; int tensor_id = GetTensorIdAndBoundary(&kernel_args, block_offset_Y, block_offset_Y + CHUNK_DIM_Y, &need_update_tensor_id); @@ -297,12 +287,20 @@ __global__ void __launch_bounds__(THREADS_NUM) amax_rowwise_ptr = reinterpret_cast(kernel_args.rowwise_amax_list[tensor_id]); split_rowwise_scale_ptr = reinterpret_cast(kernel_args.output_rowwise_scale_inv_list[tensor_id]); + if constexpr (RETURN_TRANSPOSE) { + amax_colwise_ptr = reinterpret_cast(kernel_args.colwise_amax_list[tensor_id]); + split_colwise_data_ptr = + reinterpret_cast(kernel_args.output_colwise_data_list[tensor_id]); + split_colwise_scale_ptr = + reinterpret_cast(kernel_args.output_colwise_scale_inv_list[tensor_id]); + split_colwise_data_stride = kernel_args.output_colwise_data_stride[tensor_id]; + split_colwise_scale_stride = kernel_args.output_colwise_scale_stride[tensor_id]; + } float S_enc_rowwise = 1.0f; float S_dec_rowwise = 1.0f; UpdateEncodeDecodeScaleFP32(amax_rowwise_ptr, &S_enc_rowwise, &S_dec_rowwise); - // TODO (zhongbo): colwise scaling disabled for now because of transpose float S_enc_colwise = 1.0f; float S_dec_colwise = 1.0f; if (amax_colwise_ptr != nullptr) { @@ -345,8 +343,21 @@ __global__ void __launch_bounds__(THREADS_NUM) UpdateEncodeDecodeScaleFP32(amax_rowwise_ptr, &S_enc_rowwise, &S_dec_rowwise); split_rowwise_scale_ptr = reinterpret_cast(kernel_args.output_rowwise_scale_inv_list[tensor_id]); - // TODO (zhongbo): colwise scaling disabled for now because of transpose - // Skip fetching colwise amax pointer and scaling factor updates + if constexpr (RETURN_TRANSPOSE) { + amax_colwise_ptr = reinterpret_cast(kernel_args.colwise_amax_list[tensor_id]); + if (amax_colwise_ptr != nullptr) { + UpdateEncodeDecodeScaleFP32(amax_colwise_ptr, &S_enc_colwise, &S_dec_colwise); + } else { + S_enc_colwise = S_enc_rowwise; + S_dec_colwise = S_dec_rowwise; + } + split_colwise_data_ptr = + reinterpret_cast(kernel_args.output_colwise_data_list[tensor_id]); + split_colwise_scale_ptr = reinterpret_cast( + kernel_args.output_colwise_scale_inv_list[tensor_id]); + split_colwise_data_stride = kernel_args.output_colwise_data_stride[tensor_id]; + split_colwise_scale_stride = kernel_args.output_colwise_scale_stride[tensor_id]; + } } } @@ -439,6 +450,23 @@ __global__ void __launch_bounds__(THREADS_NUM) tid_Y_t * SCALES_PER_CHUNK_Y + stage * ITERATIONS_TRANSPOSE + it; out_colwise_scales_sh[scale_idx_sh] = S_dec_b_fp8; + // Write colwise scale directly to per-split global buffer (streaming store) + if (split_colwise_scale_ptr != nullptr) { + const size_t global_row = block_offset_Y + stage_offset_Y + it * SCALE_DIM; + const size_t col_idx = block_offset_X + threadIdx.x; + const bool within_split = (global_row >= split_start) && (global_row < split_end); + if (!col_out_of_bounds_colwise && within_split) { + const size_t local_block = (global_row - split_start) / SCALE_DIM; + nvfp4_scale_t *scale_dst = + split_colwise_scale_ptr + col_idx * split_colwise_scale_stride + local_block; + asm volatile( + "st.global.cs.u8 [%0], %1;\n" + : + : "l"(scale_dst), "r"(static_cast(S_dec_b_fp8)) + : "memory"); + } + } + // Compute "correct" per-block encoding scaling factor constexpr float float_max = detail::TypeExtrema::max; const float block_scale_inverse = fminf( @@ -484,6 +512,34 @@ __global__ void __launch_bounds__(THREADS_NUM) out_t_data_sh_as_uint32_t[group] = val[0]; // idx1 = (group + 0) % 2; out_t_data_sh_as_uint32_t[(group + 1) & 1] = val[1]; // idx2 = (group + 1) % 2; } + + // Write colwise quantized data from shmem to per-split global buffers. + // Issued here (right after colwise quantize) so that the strided global + // stores overlap with the rowwise compute that follows. Each thread only + // reads its own 16-byte shmem region, so no __syncthreads is needed. + if (split_colwise_data_ptr != nullptr) { + const size_t global_stage_row = block_offset_Y + stage_offset_Y; + const size_t dst_row_t = block_offset_X + threadIdx.x; + const bool within_bounds = (dst_row_t < cols) && (global_stage_row < rows) && + (global_stage_row >= split_start) && + (global_stage_row < split_end); + if (within_bounds) { + const size_t local_stage_row = global_stage_row - split_start; + const size_t dst_byte_col = local_stage_row / 2; + fp4e2m1x2 *dst = + split_colwise_data_ptr + dst_row_t * split_colwise_data_stride + dst_byte_col; + const fp4e2m1x2 *src = + &out_t_data_sh[buff_offset_out_t + threadIdx.x * BUFF_OUT_T_DIM_X]; + // Use cache-streaming store to avoid L2 pollution from strided writes + const uint4 src_val = *reinterpret_cast(src); + asm volatile( + "st.global.cs.v4.u32 [%0], {%1, %2, %3, %4};\n" + : + : "l"(reinterpret_cast(dst)), + "r"(src_val.x), "r"(src_val.y), "r"(src_val.z), "r"(src_val.w) + : "memory"); + } + } } // ROWWISE scaling @@ -615,12 +671,6 @@ __global__ void __launch_bounds__(THREADS_NUM) const bool rowwise_scale_is_within_bounds_Y = (stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE) < chunk_rows; - // TODO(zhongbo): depending on input padding multiple (whether 128 or 64), use either scale_ptr or split_rowwise_scale_ptr - // const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X; - // if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) { - // scales_ptr[scale_idx_global] = S_dec_b_fp8; - // } - // Map to local split coordinates const size_t split_rows = split_end - split_start; const size_t local_scale_row = scales_offset_Y - split_start; @@ -688,45 +738,15 @@ __global__ void __launch_bounds__(THREADS_NUM) const size_t global_offset_Y = block_offset_Y + stage_offset_Y; const size_t global_offset_X = block_offset_X; - // TODO(zhongbo): add back when transpose is supported - // const size_t global_offset_Y_t = block_offset_Y_t; - // const size_t global_offset_X_t = block_offset_X_t + stage_offset_Y; - ptx::cp_async_bulk_tensor_2d_shared_to_global( reinterpret_cast(&tensor_map_output), global_offset_X, global_offset_Y, reinterpret_cast(&out_data_sh[buff_offset_out])); - // TODO(zhongbo): add back when transpose is supported - // if constexpr (RETURN_TRANSPOSE) { - // ptx::cp_async_bulk_tensor_2d_shared_to_global( - // reinterpret_cast(&tensor_map_output_t), global_offset_X_t, - // global_offset_Y_t, reinterpret_cast(&out_t_data_sh[buff_offset_out_t])); - // } - // Create a "bulk async-group" out of the previous bulk copy operation. ptx::cp_async_bulk_commit_group(); } - } // end of stages - // TODO(zhongbo): add back when transpose is supported - // Vectorized store scaling factors through SHMEM - // if (RETURN_TRANSPOSE && colwise_scale_is_within_bounds_Y) { - // using ScalesVec = Vec; - // const size_t scale_idx_sh = tid_Y_t * SCALES_PER_CHUNK_Y; - // ScalesVec &scales_vec = *reinterpret_cast(&out_colwise_scales_sh[scale_idx_sh]); - // const size_t scale_idx_global = scales_offset_Y_t * scale_stride_t + scales_offset_X_t; - // const size_t count = // number of scales in Y dimension of this chunk - // (chunk_rows >= CHUNK_DIM_Y) ? SCALES_PER_CHUNK_Y : (chunk_rows / SCALE_DIM); - // nvfp4_scale_t *dst = &scales_t_ptr[scale_idx_global]; - // constexpr size_t vec_bytes = SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t); - // if (count == SCALES_PER_CHUNK_Y && (reinterpret_cast(dst) % vec_bytes == 0)) { - // // Fast path: vectorized store when destination is properly aligned - // scales_vec.store_to(dst); - // } else { - // // Safe path: element-wise store for tails or unaligned destinations - // scales_vec.store_to_elts(dst, 0, count); - // } - // } + } // end of stages destroy_barriers(mbar, is_master_thread); #else @@ -753,22 +773,16 @@ void group_quantize_transpose(const Tensor &input, const Tensor *noop, "Number of tensors should be less than or equal to ", kMaxTensorsPerKernel); Tensor *output = nullptr; - // loop over the list to find the first non-empty tensor + // loop over the list to find the first non-empty tensor with actual data for (size_t i = 0; i < num_tensors; ++i) { - if (output_list[i]->has_data()) { + if (output_list[i]->has_data() && output_list[i]->data.dptr != nullptr) { output = output_list[i]; break; } } - NVTE_CHECK(output != nullptr, "No output tensor found."); - // also check that the output has not null data pointer - NVTE_CHECK(output->data.dptr != nullptr, "Output data pointer is null."); + NVTE_CHECK(output != nullptr, "No output tensor found with non-null data pointer."); - // If transposed output is allocated, return the transposed data. Otherwise, it's not necesary to - // return the transposed data. bool return_transpose = output->has_columnwise_data(); - // forbid return transpose for now because group quantize transpose is not supported yet - NVTE_CHECK(!return_transpose, "Return transpose is not supported for group quantize transpose."); // output_List is contiguous in memory, so take the first tensor as the contiguous output auto output_contiguous = output->data; @@ -803,10 +817,20 @@ void group_quantize_transpose(const Tensor &input, const Tensor *noop, reinterpret_cast(output_list[i]->amax.dptr); kernel_args.output_rowwise_scale_inv_list[kernel_args.num_tensors] = reinterpret_cast(output_list[i]->scale_inv.dptr); - // kernel_args.split_sections[kernel_args.num_tensors] = split_sections[i]; + if (return_transpose) { + kernel_args.colwise_amax_list[kernel_args.num_tensors] = + reinterpret_cast(output_list[i]->columnwise_amax.dptr); + kernel_args.output_colwise_data_list[kernel_args.num_tensors] = + reinterpret_cast(output_list[i]->columnwise_data.dptr); + kernel_args.output_colwise_scale_inv_list[kernel_args.num_tensors] = + reinterpret_cast(output_list[i]->columnwise_scale_inv.dptr); + kernel_args.output_colwise_data_stride[kernel_args.num_tensors] = + static_cast(output_list[i]->columnwise_data.shape[1] / 2); + kernel_args.output_colwise_scale_stride[kernel_args.num_tensors] = + static_cast(output_list[i]->columnwise_scale_inv.shape[1]); + } kernel_args.split_sections_range[kernel_args.num_tensors + 1] = kernel_args.split_sections_range[kernel_args.num_tensors] + split_sections[i]; - // check overflow NVTE_CHECK(kernel_args.split_sections_range[kernel_args.num_tensors + 1] >= 0, "split_sections_range overflow the int32_t"); kernel_args.num_tensors++; @@ -822,8 +846,6 @@ void group_quantize_transpose(const Tensor &input, const Tensor *noop, // for the colwise scaling, scaling factor stride is different for each tensor because of transpose // since transpose puts token dimension splits in the last dimension of the tensor const size_t scale_stride = output->scale_inv.shape[1]; - // const size_t scale_stride_transpose = - // return_transpose ? output->columnwise_scale_inv.shape[1] : 0; nvfp4_scale_t *const scales_ptr = reinterpret_cast(output->scale_inv.dptr); @@ -844,17 +866,12 @@ void group_quantize_transpose(const Tensor &input, const Tensor *noop, alignas(64) CUtensorMap tensor_map_input{}; alignas(64) CUtensorMap tensor_map_output{}; - // alignas(64) CUtensorMap tensor_map_output_transpose{}; create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0, sizeof(IType) * 8); create_2D_tensor_map(tensor_map_output, output_contiguous, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0, 4); - // if (return_transpose) { - // create_2D_tensor_map(tensor_map_output_transpose, output->columnwise_data, cols, rows, - // BUFF_DIM_X, BUFF_DIM_Y, rows, 0, 4); - // } constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; constexpr size_t buff_size_aligned_in = diff --git a/transformer_engine/common/cast/nvfp4/quantize_nvfp4_1x64_rowwise.cu b/transformer_engine/common/cast/nvfp4/quantize_nvfp4_1x64_rowwise.cu new file mode 100644 index 0000000000..72dde17e7a --- /dev/null +++ b/transformer_engine/common/cast/nvfp4/quantize_nvfp4_1x64_rowwise.cu @@ -0,0 +1,161 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "common/common.h" +#include "common/cast/nvfp4/quantize_nvfp4_1x64_rowwise.cuh" +#include "common/cast/nvfp4/core_nvfp4.cuh" +#include "common/util/ptx.cuh" +#include "common/utils.cuh" + +namespace transformer_engine { +namespace dispatch { +namespace nvfp4 { +namespace { + +#if FP4_TYPE_SUPPORTED + +using ptx::FPx2; +using quantization_SF::compute_decoding_scaling_factor; +using core::compute_global_encode_scaling_factor_FP4; + +// One CUDA block = one 1x64 K-tile in (row, K-window) layout. +// Threads reduce |x| over the tile, then S_enc = TE global formula on tile amax; +// 1x16 blocks share that S_enc. +template +__global__ void __launch_bounds__(64) nvfp4_rowwise_1x64_per_tile( + const IType* __restrict__ in, const size_t rows, const size_t cols, const int ld_row_elts, + uint8_t* __restrict__ out_data, // raw fp4 bytes (same layout as other NVFP4 rowwise) + fp8e4m3* __restrict__ row_scales, const size_t scale_stride, float* __restrict__ amax_global, + const float* __restrict__ noop) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + + const int w = static_cast(blockIdx.x); + const int r = static_cast(blockIdx.y); + const int c0 = w * 64; + if (r >= static_cast(rows) || c0 >= static_cast(cols)) { + return; + } + + const int win_len = min(64, static_cast(cols) - c0); + + __shared__ float sm[64]; + for (int i = threadIdx.x; i < 64; i += blockDim.x) { + if (i < win_len) { + sm[i] = fabsf(static_cast(in[static_cast(r) * ld_row_elts + c0 + i])); + } else { + sm[i] = 0.f; + } + } + __syncthreads(); + + if (threadIdx.x == 0) { + float wmx = 0.f; + for (int i = 0; i < 64; i++) { + wmx = fmaxf(wmx, sm[i]); + } + sm[0] = wmx; + } + __syncthreads(); + + const float tile_amax = sm[0]; + const float S_enc_tile = compute_global_encode_scaling_factor_FP4(fmaxf(tile_amax, 1e-12f)); + + if (amax_global != nullptr) { + atomicMaxFloat(amax_global, tile_amax); + } + + if (threadIdx.x < 4) { + const int b = static_cast(threadIdx.x); + const int cs = c0 + b * 16; + if (b * 16 < win_len && cs + 16 <= static_cast(cols)) { + float bmx = 0.f; + float vals[16]; + for (int e = 0; e < 16; e++) { + const float v = static_cast(in[static_cast(r) * ld_row_elts + cs + e]); + vals[e] = v; + bmx = fmaxf(bmx, fabsf(v)); + } + const fp8e4m3 s_dec = compute_decoding_scaling_factor(bmx, S_enc_tile); + const float block_scale = __fdiv_rn(S_enc_tile, static_cast(s_dec)); + + const int c16 = cs / 16; + row_scales[static_cast(r) * scale_stride + static_cast(c16)] = s_dec; + + using IType2 = FPx2; + const size_t row_bytes = static_cast(cols) / 2; // fp4 packed: cols/2 bytes per row + uint8_t* row_out = out_data + static_cast(r) * row_bytes; + for (int q = 0; q < 4; q++) { + const int e0 = q * 4; + IType2 in01{static_cast(vals[e0]), static_cast(vals[e0 + 1])}; + IType2 in23{static_cast(vals[e0 + 2]), static_cast(vals[e0 + 3])}; + fp4e2m1x4 qu{}; + ptx::mul_cvt_4x(qu, in01, in23, block_scale); + *reinterpret_cast(row_out + static_cast(cs / 2) + static_cast(2 * q)) = + qu; + } + } + } +#endif +} + +#endif // FP4_TYPE_SUPPORTED + +} // namespace + +void quantize_rowwise_1x64_local_encode(const Tensor& input, const Tensor& noop, Tensor* output, + const QuantizationConfig& /* quant_config */, cudaStream_t stream) { +#if FP4_TYPE_SUPPORTED + CheckNoopTensor(noop, "cast_noop"); + NVTE_CHECK(input.has_data(), "NVFP4 1x64: input has no data."); + NVTE_CHECK(output->has_data(), "NVFP4 1x64: output has no data."); + NVTE_CHECK(!output->has_columnwise_data(), + "NVFP4 rowwise 1x64: columnwise (transpose) path is not supported (no RHT / no GEMM)."); + NVTE_CHECK(output->scale_inv.dptr != nullptr, "NVFP4 1x64: rowwise scale_inv must be allocated."); + NVTE_CHECK(!output->with_gemm_swizzled_scales, "NVFP4 1x64: expect compact (non-gemm) scales."); + NVTE_CHECK(output->amax.dptr != nullptr, "NVFP4 1x64: rowwise amax buffer is required."); + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + if (rows == 0 || cols == 0) { + return; + } + NVTE_CHECK(cols % 16 == 0, "NVFP4 1x64: K must be a multiple of 16 (1x16 block size), got: ", + cols); + const size_t n_win = (cols + 63) / 64; + + uint8_t* out_ptr = reinterpret_cast(output->data.dptr); + fp8e4m3* scales = reinterpret_cast(output->scale_inv.dptr); + const size_t s_stride = output->scale_inv.shape.size() > 1 ? output->scale_inv.shape[1] : 1; + float* amax = reinterpret_cast(output->amax.dptr); + NVTE_CHECK_CUDA(cudaMemsetAsync(amax, 0, sizeof(float), stream)); + + dim3 grid(static_cast(n_win), static_cast(rows), 1); + constexpr int kBlock = 64; + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + input.dtype(), IType, { + const IType* in_t = reinterpret_cast(input.data.dptr); + nvfp4_rowwise_1x64_per_tile<<>>( + in_t, rows, cols, static_cast(cols), out_ptr, scales, s_stride, amax, + reinterpret_cast(noop.data.dptr)); + NVTE_CHECK_CUDA(cudaGetLastError()); + }); + +#else + (void)input; + (void)noop; + (void)output; + (void)stream; + NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); +#endif +} + +} // namespace nvfp4 +} // namespace dispatch +} // namespace transformer_engine diff --git a/transformer_engine/common/cast/nvfp4/quantize_nvfp4_1x64_rowwise.cuh b/transformer_engine/common/cast/nvfp4/quantize_nvfp4_1x64_rowwise.cuh new file mode 100644 index 0000000000..42a55dfcb7 --- /dev/null +++ b/transformer_engine/common/cast/nvfp4/quantize_nvfp4_1x64_rowwise.cuh @@ -0,0 +1,31 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file quantize_nvfp4_1x64_rowwise.cuh + * \brief NVFP4 rowwise cast with per-1x64-K-tile S_enc (non-RHT path; no columnwise / GEMM). + */ +#ifndef TRANSFORMER_ENGINE_QUANTIZE_NVFP4_1X64_ROWWISE_CUH_ +#define TRANSFORMER_ENGINE_QUANTIZE_NVFP4_1X64_ROWWISE_CUH_ + +#include + +#include "../../common.h" + +namespace transformer_engine { +namespace dispatch { +namespace nvfp4 { + +// Same TE NVFP4 math as quantize_transpose / vector_blockwise, but +// S_enc = (fp8_max*fp4_max)/max|x| over the current 1x64 K-tile in each row +// (per row, for each 64-stride K window). +void quantize_rowwise_1x64_local_encode(const Tensor& input, const Tensor& noop, Tensor* output, + const QuantizationConfig& quant_config, cudaStream_t stream); + +} // namespace nvfp4 +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_QUANTIZE_NVFP4_1X64_ROWWISE_CUH_ diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 6e207370dd..c2fe3079ad 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -470,6 +470,7 @@ struct QuantizationConfig { bool nvfp4_2d_quantization = false; bool stochastic_rounding = false; bool use_fast_math = false; + bool nvfp4_rowwise_1x64_local_encode = false; static constexpr size_t attr_sizes[] = { sizeof(uint8_t), // force_pow_2_scales @@ -479,7 +480,8 @@ struct QuantizationConfig { sizeof(NVTETensor), // rng_seed and offset sizeof(uint8_t), // nvfp4_2d_quantization sizeof(uint8_t), // stochastic_rounding - sizeof(uint8_t) // use_fast_math + sizeof(uint8_t), // use_fast_math + sizeof(uint8_t) // nvfp4_rowwise_1x64_local_encode }; }; diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index b7461a85d1..38856808df 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -370,6 +370,9 @@ enum NVTEQuantizationConfigAttribute { * inconsistently between kernels. */ kNVTEQuantizationConfigUseFastMath = 7, + /*! When true, NVFP4 non-RHT rowwise cast uses a per-1x64-K-tile amax to compute + * S_enc = (fp8_max*fp4_max)/tile_amax, instead of a single per-tensor global amax. */ + kNVTEQuantizationConfigNVFP4Rowwise1x64LocalEncode = 8, kNVTEQuantizationConfigNumAttributes }; @@ -1296,6 +1299,13 @@ class QuantizationConfigWrapper { sizeof(val)); } + /*! \brief When true, use per-1x64 (along K) local amax for NVFP4 S_enc in rowwise non-RHT cast. */ + void set_nvfp4_rowwise_1x64_local_encode(bool v) { + const auto val = static_cast(v); + nvte_set_quantization_config_attribute( + config_, kNVTEQuantizationConfigNVFP4Rowwise1x64LocalEncode, &val, sizeof(val)); + } + private: /*! \brief Wrapped NVTEQuantizationConfig. */ NVTEQuantizationConfig config_ = nullptr; diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index b97504f2ae..07956e0fd4 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -1030,6 +1030,9 @@ void nvte_get_quantization_config_attribute(NVTEQuantizationConfig config, case kNVTEQuantizationConfigUseFastMath: bool_to_uint8(config_.use_fast_math, buf); break; + case kNVTEQuantizationConfigNVFP4Rowwise1x64LocalEncode: + bool_to_uint8(config_.nvfp4_rowwise_1x64_local_encode, buf); + break; default: NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); } @@ -1085,6 +1088,9 @@ void nvte_set_quantization_config_attribute(NVTEQuantizationConfig config, case kNVTEQuantizationConfigUseFastMath: uint8_to_bool(buf, config_.use_fast_math); break; + case kNVTEQuantizationConfigNVFP4Rowwise1x64LocalEncode: + uint8_to_bool(buf, config_.nvfp4_rowwise_1x64_local_encode); + break; default: NVTE_ERROR("Unsupported NVTEQuantizationConfigAttribute (got ", static_cast(attr), ")"); } diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index b689a1c1b4..2ce423b358 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -18,6 +18,7 @@ #include "../extensions.h" #include "common.h" #include "common/util/system.h" +#include "nvfp4_1x64.h" #include "pybind.h" #include "transformer_engine/transformer_engine.h" @@ -1184,6 +1185,39 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input, } } +/// Non-RHT split: grouped amax over splits, then optional D2D copy rowwise amax -> columnwise amax. +static void split_nvfp4_non_rht_run_grouped_amax( + const TensorWrapper &input, std::vector &output_list, + std::vector &nvte_tensor_output_list, const std::vector &split_sections, + size_t num_tensors, bool copy_colwise_amax_from_rowwise, cudaStream_t stream) { + std::vector orig_amax_ptr_list; + orig_amax_ptr_list.reserve(num_tensors); + for (size_t i = 0; i < num_tensors; i++) { + auto rowwise_amax_ptr = output_list[i].get_amax().data_ptr; + orig_amax_ptr_list.push_back(rowwise_amax_ptr); + auto columnwise_amax_ptr = output_list[i].get_columnwise_amax().data_ptr; + void *amax_ptr = rowwise_amax_ptr != nullptr ? rowwise_amax_ptr : columnwise_amax_ptr; + NVTE_CHECK(amax_ptr != nullptr, "Could not find amax pointer"); + output_list[i].set_amax(amax_ptr, DType::kFloat32, std::vector{1}); + } + nvte_group_amax(input.data(), reinterpret_cast(nvte_tensor_output_list.data()), + split_sections.data(), num_tensors, stream); + for (size_t i = 0; i < num_tensors; i++) { + output_list[i].set_amax(orig_amax_ptr_list[i], DType::kFloat32, std::vector{1}); + } + if (copy_colwise_amax_from_rowwise) { + for (size_t i = 0; i < num_tensors; i++) { + auto rowwise_amax_ptr = output_list[i].get_amax().data_ptr; + auto colwise_amax_ptr = output_list[i].get_columnwise_amax().data_ptr; + if (rowwise_amax_ptr != nullptr && colwise_amax_ptr != nullptr && + rowwise_amax_ptr != colwise_amax_ptr) { + NVTE_CHECK_CUDA(cudaMemcpyAsync(colwise_amax_ptr, rowwise_amax_ptr, sizeof(float), + cudaMemcpyDeviceToDevice, stream)); + } + } + } +} + void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, const std::vector &input_list, std::vector &output_list, @@ -1193,65 +1227,71 @@ void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, const size_t num_tensors = input_list.size(); const auto &quantizer = *quantizers.front(); - std::vector nvte_tensor_input_list; std::vector nvte_tensor_output_list; for (size_t i = 0; i < num_tensors; ++i) { - nvte_tensor_input_list.push_back(input_list[i].data()); nvte_tensor_output_list.push_back(output_list[i].data()); } - // In this case without RHT, the rowwise and colwise quantization are fused - // we don't need separate rng states for rowwise and colwise - bool need_separate_rng_states = false; - // Objects for TE C API std::vector quant_config_list; for (size_t i = 0; i < num_tensors; ++i) { quant_config_list.emplace_back(QuantizationConfigWrapper()); } - // TODO: this is only true because the non-RHT path doesn't have grouped kernels yet, which we can be optimized - // so that we can generate all rng states at once - bool with_bulk_generate_rng_states = false; - + bool with_bulk_generate_rng_states = true; + bool need_separate_rng_states = false; bool need_stochastic_rounding = quantizer.stochastic_rounding; - - // place holder for colwise rng states, which are not needed in this case - std::vector dummy_quant_config_list_colwise; - + std::vector quant_config_list_colwise; + for (size_t i = 0; i < num_tensors; ++i) { + quant_config_list_colwise.emplace_back(QuantizationConfigWrapper()); + } auto stochastic_rng_state_resources = setup_stochastic_rounding_rng_states_helper( num_tensors, need_stochastic_rounding, with_bulk_generate_rng_states, - need_separate_rng_states, quant_config_list, - dummy_quant_config_list_colwise); // colwise rng states are not needed in this case - - // We need: - // 1. Rowwise amax = amax for input - // 2. Columnwise amax = amax for input too - // Columnwise amax will be filled with a fused D2D copy from rowwise amax - // Note that the multi compute amax API expects rowwise amax pointer to be not null - // So we need to set the pointer accordingly to make colwise-only quantization work - std::vector orig_amax_ptr_list; - for (size_t i = 0; i < num_tensors; i++) { - auto rowwise_amax_ptr = output_list[i].get_amax().data_ptr; - orig_amax_ptr_list.push_back(rowwise_amax_ptr); - auto columnwise_amax_ptr = output_list[i].get_columnwise_amax().data_ptr; - void *amax_ptr = rowwise_amax_ptr != nullptr ? rowwise_amax_ptr : columnwise_amax_ptr; - NVTE_CHECK(amax_ptr != nullptr, "Could not find amax pointer"); - output_list[i].set_amax(amax_ptr, DType::kFloat32, std::vector{1}); + need_separate_rng_states, quant_config_list, quant_config_list_colwise); + + const auto use_fast_math = transformer_engine::getenv("NVTE_USE_FAST_MATH"); + if (use_fast_math) { + for (auto &config : quant_config_list) { + config.set_use_fast_math(true); + } } - nvte_group_amax(input.data(), reinterpret_cast(nvte_tensor_output_list.data()), - split_sections.data(), num_tensors, stream); - for (size_t i = 0; i < num_tensors; i++) { - output_list[i].set_amax(orig_amax_ptr_list[i], DType::kFloat32, std::vector{1}); + + // 1x64: per-tensor nvte_quantize_v2 (see quantize.cuh), not the grouped amax+kernel. + const bool use_rowwise_1x64 = nvfp4_1x64::local_encode_from_env(); + for (size_t i = 0; i < num_tensors; ++i) { + nvfp4_1x64::config_apply(quant_config_list[i], quantizer.with_2d_quantization, + quantizer.stochastic_rounding, use_rowwise_1x64); } + nvfp4_1x64::require_ok_for_split(quantizer.rowwise_usage, quantizer.columnwise_usage, + quantizer.stochastic_rounding); - // Quantize tensors individually - for (size_t i = 0; i < num_tensors; i++) { - // skip this round if input is empty - if (input_list[i].numel() == 0) { - continue; + if (!use_rowwise_1x64) { + split_nvfp4_non_rht_run_grouped_amax( + input, output_list, nvte_tensor_output_list, split_sections, num_tensors, + quantizer.rowwise_usage && quantizer.columnwise_usage, stream); + } + + if (use_rowwise_1x64 && quantizer.rowwise_usage) { + for (size_t i = 0; i < num_tensors; i++) { + if (input_list[i].numel() == 0) { + continue; + } + nvte_quantize_v2(input_list[i].data(), output_list[i].data(), quant_config_list[i], stream); + } + } else if (quantizer.rowwise_usage) { + // Grouped rowwise (+ columnwise if output tensors carry colwise buffers) + // in a single kernel launch. When the output has columnwise data, the kernel + // template parameter RETURN_TRANSPOSE=true enables the colwise write-back path. + nvte_group_nvfp4_quantize_with_amax(input.data(), nvte_tensor_output_list.data(), + split_sections.data(), num_tensors, quant_config_list[0], + stream); + } else if (quantizer.columnwise_usage) { + // Colwise-only: the grouped kernel requires a contiguous rowwise output + // buffer for TMA, so fall back to per-tensor quantization. + for (size_t i = 0; i < num_tensors; i++) { + if (input_list[i].numel() == 0) continue; + nvte_quantize_v2(input_list[i].data(), output_list[i].data(), quant_config_list[i], stream); } - nvte_quantize_v2(input_list[i].data(), output_list[i].data(), quant_config_list[i], stream); } } diff --git a/transformer_engine/pytorch/csrc/nvfp4_1x64.h b/transformer_engine/pytorch/csrc/nvfp4_1x64.h new file mode 100644 index 0000000000..389241af2b --- /dev/null +++ b/transformer_engine/pytorch/csrc/nvfp4_1x64.h @@ -0,0 +1,64 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file nvfp4_1x64.h + * \brief Small helpers for NVFP4 per-1x64-K S_enc (env + config + preconditions), shared + * between NVFP4Quantizer and split_quantize to avoid duplicating policy. + */ +#ifndef TRANSFORMER_ENGINE_PYTORCH_NVFP4_1X64_H_ +#define TRANSFORMER_ENGINE_PYTORCH_NVFP4_1X64_H_ + +#include "common/util/logging.h" +#include "common/util/system.h" +#include + +namespace transformer_engine::pytorch::nvfp4_1x64 { + +/// Whether rowwise 1x64 local encode is requested (TE-wide env, same for single-tensor and split). +[[nodiscard]] inline bool local_encode_from_env() { + return transformer_engine::getenv("NVTE_NVFP4_ROWWISE_1X64_LOCAL_ENCODE", false); +} + +/// Apply 2D mode, SR, and optional 1x64 flag to a quantization config (mirrors what NVFP4 needs). +inline void config_apply(QuantizationConfigWrapper& cfg, bool nvfp4_2d, bool stochastic_rounding, + bool use_rowwise_1x64) { + cfg.set_nvfp4_2d_quantization(nvfp4_2d); + cfg.set_stochastic_rounding(stochastic_rounding); + cfg.set_nvfp4_rowwise_1x64_local_encode(use_rowwise_1x64); +} + +/// Preconditions for \p NVFP4Quantizer::quantize_impl (non-split). +inline void require_ok_for_non_split(bool with_rht, bool columnwise, bool sr) { + if (!local_encode_from_env()) { + return; + } + NVTE_CHECK( + !with_rht, + "NVTE_NVFP4_ROWWISE_1X64_LOCAL_ENCODE=1 requires non-RHT (e.g. NVTE_NVFP4_DISABLE_RHT=1)."); + NVTE_CHECK( + !columnwise, + "NVTE_NVFP4_ROWWISE_1X64_LOCAL_ENCODE=1 supports rowwise-only NVFP4 output."); + NVTE_CHECK(!sr, "NVTE_NVFP4_ROWWISE_1X64_LOCAL_ENCODE=1 is incompatible with stochastic rounding."); +} + +/// Preconditions for \p split_quantize (non-RHT path). +inline void require_ok_for_split(bool want_rowwise, bool have_columnwise, bool sr) { + if (!local_encode_from_env()) { + return; + } + NVTE_CHECK( + want_rowwise, + "NVTE_NVFP4_ROWWISE_1X64_LOCAL_ENCODE in split_quantize requires rowwise output."); + NVTE_CHECK(!have_columnwise, + "NVTE_NVFP4_ROWWISE_1X64_LOCAL_ENCODE in split_quantize does not support columnwise " + "output."); + NVTE_CHECK( + !sr, "NVTE_NVFP4_ROWWISE_1X64_LOCAL_ENCODE in split_quantize is incompatible with SR."); +} + +} // namespace transformer_engine::pytorch::nvfp4_1x64 + +#endif // TRANSFORMER_ENGINE_PYTORCH_NVFP4_1X64_H_ diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index b59f3fa3c5..070549442b 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -8,6 +8,7 @@ #include "common.h" #include "common/util/system.h" +#include "nvfp4_1x64.h" #include "pybind.h" #include "torch/torch.h" @@ -2227,8 +2228,9 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou quant_config.set_noop_tensor(noop_flag->data()); quant_config_columnwise.set_noop_tensor(noop_flag->data()); } - quant_config.set_nvfp4_2d_quantization(this->with_2d_quantization); - quant_config.set_stochastic_rounding(this->stochastic_rounding); + nvfp4_1x64::config_apply(quant_config, this->with_2d_quantization, this->stochastic_rounding, + nvfp4_1x64::local_encode_from_env()); + nvfp4_1x64::require_ok_for_non_split(this->with_rht, this->columnwise_usage, this->stochastic_rounding); // We only need RHT for columnwise usage. // flat first dim and last dim for multi dimensional input From c801693aae73f6ace9ced211ac48c1b8564b1f33 Mon Sep 17 00:00:00 2001 From: Cael Ling Date: Sun, 26 Apr 2026 22:45:37 -0700 Subject: [PATCH 02/10] Add bit-exact reference + tests for NVFP4 rowwise 1x64 local-encode Signed-off-by: Cael Ling --- .../nvfp4/test_nvfp4_1x64_quantize_exact.py | 194 ++++++++++++++++ .../custom_recipes/quantization_nvfp4_1x64.py | 215 ++++++++++++++++++ 2 files changed, 409 insertions(+) create mode 100644 tests/pytorch/nvfp4/test_nvfp4_1x64_quantize_exact.py create mode 100644 transformer_engine/pytorch/custom_recipes/quantization_nvfp4_1x64.py diff --git a/tests/pytorch/nvfp4/test_nvfp4_1x64_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_1x64_quantize_exact.py new file mode 100644 index 0000000000..694382970f --- /dev/null +++ b/tests/pytorch/nvfp4/test_nvfp4_1x64_quantize_exact.py @@ -0,0 +1,194 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Bit-exact tests for the NVFP4 rowwise 1x64 local-encode CUDA kernel. + +Methodology mirrors ``test_nvfp4_quantize_exact.py``: invoke the kernel +through ``NVFP4Quantizer`` (with the ``NVTE_NVFP4_ROWWISE_1X64_LOCAL_ENCODE=1`` +env flag enabling the alternate dispatch), and compare the resulting +``(qx, sx, amax)`` triple against the pure-PyTorch oracle +``NVFP4Quantizer1x64Ref`` byte-for-byte (``atol=rtol=0``). +""" + +import pytest +import torch + +import transformer_engine.pytorch as te +import transformer_engine_torch as tex +from transformer_engine.pytorch import NVFP4Quantizer +from transformer_engine.pytorch.custom_recipes.quantization_nvfp4_1x64 import ( + NVFP4Quantizer1x64Ref, +) + + +recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) + + +def unpack_fp4(x: torch.Tensor) -> torch.Tensor: + """Unpack two FP4 values per byte into one ``uint8`` value per element. + + Identical to the helper in ``test_nvfp4_quantize_exact.py`` -- duplicated + here so the two test suites stay independent. + """ + repeated = x.repeat_interleave(2, dim=1) + repeated[:, 0::2] &= 0x0F + repeated[:, 1::2] >>= 4 + return repeated + + +def _check_quantization_1x64_versus_reference( + x_dtype: torch.dtype, + M: int, + N: int, +) -> None: + """Quantize ``(M, N)`` random input through both kernel and reference and + assert the rowwise data, scale, and global amax all match bit-exactly.""" + te_dtype = tex.DType.kFloat4E2M1 + device = "cuda" + + seed = 0 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + x = torch.randn((M, N), dtype=x_dtype, device=device) + + # Kernel path. ``with_rht=False`` and ``with_2d_quantization=False`` are + # required by the 1x64 dispatch's preconditions; ``columnwise=False`` + # because the kernel does not produce a transposed output. + quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=False, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=False, + ) + x_nvfp4_sut = quantizer(x) + + qx = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8) + sx = x_nvfp4_sut._rowwise_scale_inv + qx_amax = x_nvfp4_sut._amax_rowwise + + # Reference path. + ref = NVFP4Quantizer1x64Ref().quantize(x) + qx_ref = unpack_fp4(ref.data.view(dtype=torch.uint8)) + sx_ref = ref.scale.view(dtype=torch.uint8) + ref_amax = ref.global_amax_row + + qx = unpack_fp4(qx) + + # The kernel may pad qx/sx to alignment boundaries; only the unpadded + # prefix is meaningfully written. Slice both sides down to the reference + # shape before comparing (the reference returns un-padded tensors). + qx_valid = qx[: qx_ref.shape[0], : qx_ref.shape[1]] + sx_valid = sx[: sx_ref.shape[0], : sx_ref.shape[1]] + + torch.testing.assert_close(qx_valid, qx_ref, atol=0.0, rtol=0.0) + torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0) + torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, N", + [ + # K is a multiple of WINDOW_K=64 + (128, 128), + (256, 256), + (256, 512), + (1024, 256), + (256, 1024), + (2048, 2048), + (1024, 2048), + # K is a multiple of BLOCK_K=16 but not of WINDOW_K -- exercises the + # partial-last-window path in the kernel. + (256, 80), + (256, 272), + (256, 336), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +def test_nvfp4_1x64_quantize_versus_reference( + monkeypatch, + x_dtype: torch.dtype, + M: int, + N: int, +) -> None: + """Random-input bit-exact test across a representative shape grid.""" + monkeypatch.setenv("NVTE_NVFP4_ROWWISE_1X64_LOCAL_ENCODE", "1") + monkeypatch.setenv("NVTE_NVFP4_DISABLE_RHT", "1") + _check_quantization_1x64_versus_reference(x_dtype, M, N) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +def test_nvfp4_1x64_quantize_extrema(monkeypatch, x_dtype: torch.dtype) -> None: + """Stress the saturating-cast and ``S_enc`` clamps with extreme inputs. + + Inputs: + * an outlier-heavy row (one ~FP4_MAX value per 1x64 window plus tiny + noise) -- exercises the ``s_dec`` saturation branch; + * an all-zero region -- exercises the ``tile_amax == 0`` fallback that + promotes ``S_enc`` to 1.0; + * a uniform constant row -- ``s_dec`` should be the same E4M3 byte for + every block in the row. + """ + monkeypatch.setenv("NVTE_NVFP4_ROWWISE_1X64_LOCAL_ENCODE", "1") + monkeypatch.setenv("NVTE_NVFP4_DISABLE_RHT", "1") + + device = "cuda" + M, N = 64, 256 + + torch.manual_seed(0) + torch.cuda.manual_seed(0) + x = torch.randn((M, N), dtype=x_dtype, device=device) * 0.01 + + # Row 0: per-window outliers. + for w in range(N // 64): + x[0, w * 64] = 5.5 + + # Row 1: all zero (degenerate window amax path). + x[1] = 0.0 + + # Row 2: uniform constant. + x[2] = 0.75 + + _check_quantization_1x64_versus_reference_with_input(x) + + +def _check_quantization_1x64_versus_reference_with_input(x: torch.Tensor) -> None: + """Variant of the random-input checker that consumes a caller-provided + tensor (used by the extrema test).""" + te_dtype = tex.DType.kFloat4E2M1 + + quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=False, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=False, + ) + x_nvfp4_sut = quantizer(x) + + qx = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8) + sx = x_nvfp4_sut._rowwise_scale_inv + qx_amax = x_nvfp4_sut._amax_rowwise + + ref = NVFP4Quantizer1x64Ref().quantize(x) + qx_ref = unpack_fp4(ref.data.view(dtype=torch.uint8)) + sx_ref = ref.scale.view(dtype=torch.uint8) + ref_amax = ref.global_amax_row + + qx = unpack_fp4(qx) + + qx_valid = qx[: qx_ref.shape[0], : qx_ref.shape[1]] + sx_valid = sx[: sx_ref.shape[0], : sx_ref.shape[1]] + + torch.testing.assert_close(qx_valid, qx_ref, atol=0.0, rtol=0.0) + torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0) + torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0) diff --git a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_1x64.py b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_1x64.py new file mode 100644 index 0000000000..e15ad76890 --- /dev/null +++ b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_1x64.py @@ -0,0 +1,215 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Reference implementation for NVFP4 rowwise 1x64 local-encode quantization. + +The hierarchical 1x64 + 1x16 scheme replaces the per-tensor encoding scaling +factor used by stock NVFP4 with a per-1x64-K-window scaling factor; the four +1x16 sub-blocks inside a window share their parent ``S_enc``. The CUDA kernel +that implements this lives in +``transformer_engine/common/cast/nvfp4/quantize_nvfp4_1x64_rowwise.cu`` and is +dispatched when ``NVTE_NVFP4_ROWWISE_1X64_LOCAL_ENCODE=1``. + +This file mirrors that kernel's arithmetic in pure PyTorch so tests can +compare the kernel's output byte-for-byte against a Python oracle (the same +bit-exact methodology used by ``NVFP4QuantizerRef`` for the production NVFP4 +path). The arithmetic ordering and intermediate clamps are chosen to match +what the kernel does: + +* ``S_enc_tile = (FP8_MAX*FP4_MAX) / max(tile_amax, 1e-12)`` clamped to + ``fp32_max``; +* ``s_dec = saturating_cast(vec_max * S_enc_tile / FP4_MAX)`` (the + ``1/FP4_MAX`` is folded into the ``S_enc`` multiplier to match + ``compute_decoding_scaling_factor``); +* ``block_scale = S_enc_tile / fp32(s_dec)`` (matches ``__fdiv_rn`` in the + kernel); +* ``q = round_fp4_satfinite(x_fp32 * block_scale)`` with values clamped to + ``[-FP4_MAX, FP4_MAX]`` before packing. + +Only the rowwise, non-RHT, non-2D, non-stochastic-rounding path is supported, +matching the kernel's preconditions. +""" + +from __future__ import annotations + +import dataclasses +from typing import Optional, Tuple + +import torch + +from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import cast_to_fp4x2 + +# Window/block geometry is fixed by the kernel design. +WINDOW_K: int = 64 +BLOCK_K: int = 16 +BLOCKS_PER_WINDOW: int = WINDOW_K // BLOCK_K # 4 + +# E2M1 / E4M3 numeric extrema (matches ``TypeExtrema`` in core_nvfp4.cuh). +FLOAT4_E2M1_MAX: float = 6.0 +FLOAT8_E4M3_MAX: float = 448.0 + +# Matches the kernel's ``fmaxf(tile_amax, 1e-12f)`` clamp guarding the divisor +# of ``compute_global_encode_scaling_factor_FP4``. +_TILE_AMAX_FLOOR: float = 1e-12 + + +@dataclasses.dataclass +class RefNVFP4Tensor1x64: + """Container for the rowwise 1x64 reference output. + + Mirrors the subset of attributes that the bit-exact test inspects. + Naming follows ``quantization_nvfp4.RefNVFP4Tensor`` so the test reads the + same way as ``check_quantization_nvfp4_versus_reference``. + + Attributes + ---------- + data: + Packed FP4 bytes, ``(M, N // 2)`` ``uint8``. + scale: + Per-1x16-block decode scale (E4M3), ``(M, N // 16)`` ``float8_e4m3fn``. + global_amax_row: + Global tensor amax (1-D, single fp32 element). Equals the result of + the kernel's ``atomicMaxFloat`` over all per-tile amax values. + """ + + data: Optional[torch.Tensor] = None + scale: Optional[torch.Tensor] = None + global_amax_row: Optional[torch.Tensor] = None + + +class NVFP4Quantizer1x64Ref: + """Reference implementation of the rowwise 1x64 local-encode kernel. + + The constructor takes no parameters because the kernel itself does not + expose any -- columnwise output, RHT, 2D scaling, and stochastic rounding + are all rejected at dispatch time. Surfacing those as ctor flags here + would only invite the test to drift away from the kernel's actual + capabilities. + """ + + def __init__(self) -> None: + # No configurable knobs; see class docstring. + pass + + @staticmethod + def _quantize_rowwise(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Run the 1x64 reference math on a 2D input. + + Returns + ------- + ``(qx, sx, global_amax)`` where shapes match the kernel's compact + rowwise layout: ``qx`` is ``(M, N // 2)`` ``uint8``, ``sx`` is + ``(M, N // BLOCK_K)`` ``float8_e4m3fn``, and ``global_amax`` is + ``(1,)`` ``float32``. + """ + if x.ndim != 2: + raise ValueError(f"NVFP4Quantizer1x64Ref expects a 2D tensor, got {x.ndim}D") + M, N = x.shape + if N % BLOCK_K != 0: + raise ValueError( + f"N={N} must be a multiple of BLOCK_K={BLOCK_K} (kernel hard requirement)" + ) + + device = x.device + fp32_max = torch.tensor( + torch.finfo(torch.float32).max, device=device, dtype=torch.float32 + ) + fp4_max = torch.tensor(FLOAT4_E2M1_MAX, device=device, dtype=torch.float32) + fp8_max = torch.tensor(FLOAT8_E4M3_MAX, device=device, dtype=torch.float32) + + # Pad K up to a multiple of WINDOW_K so the reshape into windows is + # well-defined. The kernel itself supports a partial last window via + # the ``win_len`` clamp; we emulate that by zero-padding here and + # trimming the padded columns out of qx/sx at the end (the padded + # blocks are uninitialised in the kernel output, so the test compares + # only the un-padded prefix). + pad_n = (WINDOW_K - N % WINDOW_K) % WINDOW_K + if pad_n > 0: + x_padded = torch.nn.functional.pad(x, (0, pad_n), mode="constant", value=0.0) + else: + x_padded = x.contiguous() + Np = x_padded.shape[1] + n_win = Np // WINDOW_K + n_blk = Np // BLOCK_K + + x_padded_fp32 = x_padded.to(torch.float32) + x_win = x_padded_fp32.view(M, n_win, WINDOW_K) + x_blk = x_padded_fp32.view(M, n_blk, BLOCK_K) + + # 1x64 tile amax. The kernel applies fmaxf(tile_amax, 1e-12f) to the + # divisor; do the same here. ``S_enc_tile`` is then computed exactly + # like ``compute_global_encode_scaling_factor_FP4``: divide, clamp to + # fp32_max, and fall back to 1.0 if the divisor or quotient is zero + # (the latter branch is dead given the floor but is mirrored for + # parity). + tile_amax = torch.amax(torch.abs(x_win), dim=-1, keepdim=True) # (M, n_win, 1) + tile_amax_safe = torch.clamp(tile_amax, min=_TILE_AMAX_FLOOR) + + S_enc_tile = (fp8_max * fp4_max) / tile_amax_safe + S_enc_tile = torch.minimum(S_enc_tile, fp32_max) + S_enc_tile = torch.where( + (tile_amax_safe == 0) | (S_enc_tile == 0), + torch.ones_like(S_enc_tile), + S_enc_tile, + ) + + # Fold (1 / fp4_max) into the multiplier the same way the kernel does + # in ``compute_decoding_scaling_factor`` (``S_enc * fp4_max_inv``). + # Keeping the operation order identical is what makes the resulting + # E4M3 scale bit-exact with the kernel. + S_enc_tile_mul_inv6 = S_enc_tile * torch.reciprocal(fp4_max) + + # 1x16 block amax and per-block S_enc broadcast. Each 1x64 window + # spans exactly BLOCKS_PER_WINDOW (=4) consecutive 1x16 sub-blocks, so + # repeat_interleave along the block axis aligns one S_enc_tile to + # every block inside that window. + vec_max = torch.amax(torch.abs(x_blk), dim=-1, keepdim=True) # (M, n_blk, 1) + S_enc_per_blk = S_enc_tile.repeat_interleave(BLOCKS_PER_WINDOW, dim=1) + S_enc_per_blk_mul = S_enc_tile_mul_inv6.repeat_interleave(BLOCKS_PER_WINDOW, dim=1) + + # decode_scale = saturating_cast(vec_max * S_enc_tile / 6). + # The kernel does not clamp before the cast; we do, because PyTorch's + # ``.to(float8_e4m3fn)`` does not match CUDA's saturating cast for + # values above FP8_MAX. After the explicit clamp the two paths agree. + decode_scale_fp32 = vec_max * S_enc_per_blk_mul + decode_scale_fp32 = torch.minimum(decode_scale_fp32, fp32_max) + decode_scale_fp32 = torch.clamp(decode_scale_fp32, min=-fp8_max, max=fp8_max) + decode_scale_e4m3 = decode_scale_fp32.to(torch.float8_e4m3fn) + decode_scale_back_fp32 = decode_scale_e4m3.to(torch.float32) + + # block_scale = S_enc_tile / s_dec, matching ``__fdiv_rn`` in the + # kernel. Padded sub-blocks have vec_max == 0 hence s_dec == 0, which + # would yield +inf here and propagate NaN through the downstream + # multiply. To keep the division warning-free we replace zero + # divisors with 1.0, divide, then mask the result back to zero -- the + # padded slots are trimmed out of the final comparison either way. + zero_blk = decode_scale_back_fp32 == 0 + denom = torch.where(zero_blk, torch.ones_like(decode_scale_back_fp32), decode_scale_back_fp32) + encode_scale = S_enc_per_blk / denom + encode_scale = torch.where(zero_blk, torch.zeros_like(encode_scale), encode_scale) + encode_scale = torch.minimum(encode_scale, fp32_max) + + # Apply scale, clamp to FP4 range, and pack two FP4 values per byte. + scaled_x = x_blk * encode_scale + clipped_x = torch.clamp(scaled_x, -fp4_max, fp4_max).reshape(M, Np) + qx_packed_padded = cast_to_fp4x2(clipped_x) # (M, Np // 2) + + sx_padded = decode_scale_e4m3.squeeze(-1) # (M, n_blk) + + # Trim the K-direction padding so the returned tensors describe only + # positions the kernel actually wrote to. + qx = qx_packed_padded[:, : N // 2].contiguous() + sx = sx_padded[:, : N // BLOCK_K].contiguous() + + # ``output.amax`` in the kernel accumulates ``atomicMaxFloat`` over + # every per-tile amax, which is mathematically max-of-maxes -- i.e. + # the global tensor amax. Compute that directly here. + global_amax = torch.amax(torch.abs(x.to(torch.float32))).reshape(1) + + return qx, sx, global_amax + + def quantize(self, tensor: torch.Tensor) -> RefNVFP4Tensor1x64: + """Quantize ``tensor`` and return a ``RefNVFP4Tensor1x64``.""" + qx, sx, global_amax = self._quantize_rowwise(tensor) + return RefNVFP4Tensor1x64(data=qx, scale=sx, global_amax_row=global_amax) From 7a39ad6c9cf46a6f46bd611e190d7d600900101c Mon Sep 17 00:00:00 2001 From: Cael Ling Date: Mon, 27 Apr 2026 01:29:23 -0700 Subject: [PATCH 03/10] Add accuracy comparison test for NVFP4 rowwise 1x64 vs per-tensor Signed-off-by: Cael Ling --- .../pytorch/nvfp4/test_nvfp4_1x64_accuracy.py | 315 ++++++++++++++++++ 1 file changed, 315 insertions(+) create mode 100644 tests/pytorch/nvfp4/test_nvfp4_1x64_accuracy.py diff --git a/tests/pytorch/nvfp4/test_nvfp4_1x64_accuracy.py b/tests/pytorch/nvfp4/test_nvfp4_1x64_accuracy.py new file mode 100644 index 0000000000..d437f9c365 --- /dev/null +++ b/tests/pytorch/nvfp4/test_nvfp4_1x64_accuracy.py @@ -0,0 +1,315 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Accuracy comparison: hierarchical 1x64 + 1x16 vs per-tensor + 1x16. + +The hypothesis under test is that per-1x64-window ``S_enc`` (the hierarchical +scheme) reconstructs ``x`` at least as accurately as per-tensor ``S_enc`` +(the production NVFP4 scheme), and strictly better when the K-direction +magnitude varies meaningfully across windows -- which is the case the +hierarchical scheme is built for. + +Methodology +----------- +This is a *spec-level* accuracy test, not a kernel test: + +1. Quantize the same input through two pure-PyTorch references -- + ``NVFP4QuantizerRef`` (per-tensor) and ``NVFP4Quantizer1x64Ref`` + (per-1x64-window) -- to obtain ``(qx, sx)`` for each. +2. Dequantize both back to fp32 using each scheme's *own* ``S_enc``: + ``x_recon = q * s_dec_fp32 / S_enc``. For the per-tensor scheme + ``S_enc`` is a scalar; for 1x64 it is per-window. (Using the per-tensor + ``S_enc`` to dequantize a 1x64-encoded tensor would re-introduce the + GEMM-mismatch bug -- a separate concern that is documented elsewhere.) +3. Compute reconstruction error metrics (RMSE, max abs, Frobenius-relative) + against the original fp32 input. +4. Assert ``rmse_1x64`` is at most a small slack worse than ``rmse_pt`` on + benign inputs, and strictly better on inputs with per-window dynamic + range. Numbers are also printed so a regression failure is diagnostic. + +Because the bit-exact tests in ``test_nvfp4_1x64_quantize_exact.py`` already +certify ``NVFP4Quantizer1x64Ref`` reproduces the CUDA kernel byte-for-byte, +any accuracy conclusion we draw at the reference level transfers directly +to the kernel. +""" + +from __future__ import annotations + +from typing import Tuple + +import pytest +import torch + +from transformer_engine.pytorch.custom_recipes import utils +from transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import ( + NVFP4QuantizerRef, + cast_from_fp4x2, +) +from transformer_engine.pytorch.custom_recipes.quantization_nvfp4_1x64 import ( + BLOCK_K, + FLOAT4_E2M1_MAX, + FLOAT8_E4M3_MAX, + NVFP4Quantizer1x64Ref, + WINDOW_K, +) + + +_SAFE_AMAX_FLOOR = 1e-12 + + +def _per_tensor_s_enc(x: torch.Tensor) -> torch.Tensor: + """Per-tensor encoding scaling factor used by stock NVFP4. + + Returns a per-element broadcast of shape ``(M, N)`` so the dequant + formula can be expressed elementwise without further care for layout. + """ + M, N = x.shape + fp32_max = torch.tensor(torch.finfo(torch.float32).max, device=x.device, dtype=torch.float32) + global_amax = torch.amax(torch.abs(x.to(torch.float32))) + s = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / torch.clamp(global_amax, min=_SAFE_AMAX_FLOOR) + s = torch.minimum(s, fp32_max) + if float(s.item()) == 0.0: + s = torch.ones_like(s) + return s.expand(M, N).contiguous() + + +def _per_window_s_enc(x: torch.Tensor) -> torch.Tensor: + """Per-1x64-window encoding scaling factor, broadcast to ``(M, N)``.""" + M, N = x.shape + pad_n = (WINDOW_K - N % WINDOW_K) % WINDOW_K + if pad_n > 0: + x_padded = torch.nn.functional.pad(x, (0, pad_n), mode="constant", value=0.0) + else: + x_padded = x.contiguous() + Np = x_padded.shape[1] + n_win = Np // WINDOW_K + + fp32_max = torch.tensor(torch.finfo(torch.float32).max, device=x.device, dtype=torch.float32) + x_padded_fp32 = x_padded.to(torch.float32).view(M, n_win, WINDOW_K) + tile_amax = torch.amax(torch.abs(x_padded_fp32), dim=-1, keepdim=True) + tile_amax_safe = torch.clamp(tile_amax, min=_SAFE_AMAX_FLOOR) + + s = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / tile_amax_safe + s = torch.minimum(s, fp32_max) + s = torch.where( + (tile_amax_safe == 0) | (s == 0), + torch.ones_like(s), + s, + ) + s_per_elt = s.squeeze(-1).repeat_interleave(WINDOW_K, dim=1) + return s_per_elt[:, :N].contiguous() + + +def _dequantize( + qx_packed: torch.Tensor, + sx_e4m3: torch.Tensor, + s_enc_per_elt: torch.Tensor, + M: int, + N: int, +) -> torch.Tensor: + """Inverse of NVFP4 forward quantization: ``q * s_dec_fp32 / S_enc``.""" + q_fp32 = cast_from_fp4x2(qx_packed.view(torch.uint8), torch.float32) + sx_fp32 = sx_e4m3.view(torch.float8_e4m3fn).to(torch.float32) + sx_fp32 = sx_fp32[:M, : N // BLOCK_K] + sx_per_elt = sx_fp32.repeat_interleave(BLOCK_K, dim=1) + return q_fp32 * sx_per_elt / s_enc_per_elt + + +def _recon_per_tensor(x: torch.Tensor) -> torch.Tensor: + """Forward + inverse via the production per-tensor NVFP4 reference.""" + M, N = x.shape + quantizer = NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + rowwise=True, + columnwise=False, + pow_2_scales=False, + eps=0.0, + quant_tile_shape=(1, 16), + with_rht=False, + ) + out = quantizer.quantize(x) + s_enc = _per_tensor_s_enc(x) + return _dequantize(out.data, out.scale, s_enc, M, N) + + +def _recon_1x64(x: torch.Tensor) -> torch.Tensor: + """Forward + inverse via the hierarchical 1x64 reference.""" + M, N = x.shape + out = NVFP4Quantizer1x64Ref().quantize(x) + s_enc = _per_window_s_enc(x) + return _dequantize(out.data, out.scale, s_enc, M, N) + + +def _err_metrics(x: torch.Tensor, recon: torch.Tensor) -> Tuple[float, float, float]: + """Return ``(rmse, max_abs_err, frobenius_relative)`` of ``recon - x``.""" + x_fp32 = x.to(torch.float32) + diff = recon.to(torch.float32) - x_fp32 + rmse = torch.sqrt(torch.mean(diff * diff)).item() + max_err = torch.max(torch.abs(diff)).item() + denom = torch.linalg.norm(x_fp32).clamp(min=1e-30) + frob_rel = (torch.linalg.norm(diff) / denom).item() + return rmse, max_err, frob_rel + + +def _gen_gaussian( + M: int, N: int, *, seed: int, device: str, dtype: torch.dtype +) -> torch.Tensor: + """Uniform N(0, 1) -- a benign baseline where both schemes should tie.""" + g = torch.Generator(device=device).manual_seed(seed) + return torch.randn((M, N), generator=g, device=device, dtype=dtype) + + +def _gen_per_window_dynamic_range( + M: int, + N: int, + *, + seed: int, + device: str, + dtype: torch.dtype, + log10_lo: float, + log10_hi: float, +) -> torch.Tensor: + """Each 1x64 window has its own log-uniform magnitude scale. + + This is the scenario the hierarchical scheme is built for: per-window + ``S_enc`` adapts to local magnitude while the per-tensor ``S_enc`` is + pinned to the loudest window and crushes precision in the quiet ones + (or, in the extreme, rounds the quiet windows' E4M3 ``s_dec`` to zero). + + ``log10_lo`` / ``log10_hi`` control the dynamic range: + * ``[-1.5, 0.5]`` -- modest (~30x ratio), small-but-real advantage. + * ``[-5.0, 0.5]`` -- extreme (~3e5 ratio), large advantage. + """ + g = torch.Generator(device=device).manual_seed(seed) + n_win = (N + WINDOW_K - 1) // WINDOW_K + log_scales = torch.empty((M, n_win, 1), device=device, dtype=torch.float32) + log_scales.uniform_(log10_lo, log10_hi, generator=g) + scales = torch.pow(torch.tensor(10.0, device=device, dtype=torch.float32), log_scales) + base = torch.randn( + (M, n_win, WINDOW_K), generator=g, device=device, dtype=torch.float32 + ) + x = (base * scales).reshape(M, n_win * WINDOW_K)[:, :N].contiguous() + return x.to(dtype) + + +_NEEDS_CUDA = pytest.mark.skipif( + not torch.cuda.is_available(), + reason="accuracy comparison runs on CUDA to mirror the rest of the nvfp4 suite", +) + + +@_NEEDS_CUDA +@pytest.mark.parametrize("M, N", [(256, 1024), (512, 2048), (1024, 1024)]) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("seed", [0, 1, 2]) +def test_1x64_at_least_as_good_as_per_tensor_on_gaussian( + M: int, N: int, x_dtype: torch.dtype, seed: int, capsys +) -> None: + """On uniform Gaussian inputs the two schemes should be roughly tied. + + ``S_enc_global`` and per-window ``S_enc_tile`` see similar amax values + when the input magnitude is uniform across K, so we expect ratios near + 1.0. We allow a 5% slack to absorb E4M3-rounding noise; a regression + that materially worsens 1x64 accuracy on benign inputs would still be + caught. + """ + device = "cuda" + x = _gen_gaussian(M, N, seed=seed, device=device, dtype=x_dtype) + + rmse_pt, max_pt, fro_pt = _err_metrics(x, _recon_per_tensor(x)) + rmse_1x64, max_1x64, fro_1x64 = _err_metrics(x, _recon_1x64(x)) + + with capsys.disabled(): + print( + f"\n[gaussian {M}x{N} {x_dtype} seed={seed}]" + f" rmse: pt={rmse_pt:.4e} 1x64={rmse_1x64:.4e}" + f" ratio={rmse_1x64 / max(rmse_pt, 1e-30):.3f} |" + f" max_abs: pt={max_pt:.4e} 1x64={max_1x64:.4e} |" + f" frob_rel: pt={fro_pt:.4e} 1x64={fro_1x64:.4e}" + ) + + assert rmse_1x64 <= rmse_pt * 1.05, ( + f"1x64 RMSE unexpectedly worse than per-tensor on uniform input: " + f"rmse_1x64={rmse_1x64:.4e} > 1.05 * rmse_pt={rmse_pt:.4e}" + ) + + +@_NEEDS_CUDA +@pytest.mark.parametrize("M, N", [(256, 1024), (512, 2048)]) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("seed", [0, 1, 2]) +def test_1x64_strictly_better_on_extreme_per_window_dynamic_range( + M: int, N: int, x_dtype: torch.dtype, seed: int, capsys +) -> None: + """When per-window magnitude spans ~5 orders of magnitude, 1x64 must win. + + With ``log10_lo = -5`` the loudest-to-quietest window scale ratio is + ~3e5. The per-tensor ``S_enc`` is pinned to the loudest window, which + drives ``s_dec`` for quiet windows toward (or into) E4M3 underflow -- + catastrophic for those positions. The hierarchical scheme assigns each + window its own ``S_enc_tile`` and recovers near-full precision in the + quiet ones, so the overall RMSE drops by at least 2x. + """ + device = "cuda" + x = _gen_per_window_dynamic_range( + M, N, seed=seed, device=device, dtype=x_dtype, log10_lo=-5.0, log10_hi=0.5 + ) + + rmse_pt, max_pt, fro_pt = _err_metrics(x, _recon_per_tensor(x)) + rmse_1x64, max_1x64, fro_1x64 = _err_metrics(x, _recon_1x64(x)) + + with capsys.disabled(): + print( + f"\n[dyn_range_extreme {M}x{N} {x_dtype} seed={seed}]" + f" rmse: pt={rmse_pt:.4e} 1x64={rmse_1x64:.4e}" + f" ratio={rmse_1x64 / max(rmse_pt, 1e-30):.3f} |" + f" max_abs: pt={max_pt:.4e} 1x64={max_1x64:.4e} |" + f" frob_rel: pt={fro_pt:.4e} 1x64={fro_1x64:.4e}" + ) + + assert rmse_1x64 < rmse_pt * 0.5, ( + f"1x64 failed to outperform per-tensor on extreme dynamic-range input " + f"by the expected margin: rmse_1x64={rmse_1x64:.4e} vs" + f" 0.5 * rmse_pt={0.5 * rmse_pt:.4e}" + ) + + +@_NEEDS_CUDA +@pytest.mark.parametrize("M, N", [(256, 1024)]) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("seed", [0, 1, 2]) +def test_1x64_at_least_tied_on_modest_per_window_dynamic_range( + M: int, N: int, x_dtype: torch.dtype, seed: int, capsys +) -> None: + """A moderate per-window dynamic range still favours 1x64. + + Ratio range of ~30x is well within E4M3's representable scale range, + so the per-tensor scheme does not catastrophically underflow. The + advantage from local ``S_enc`` is therefore smaller -- but still + present. We assert merely "no worse" plus a generous 5% slack; the + strict inequality test above is the one that demonstrates the win, + while this case ensures the win does not invert when the dynamic + range shrinks. + """ + device = "cuda" + x = _gen_per_window_dynamic_range( + M, N, seed=seed, device=device, dtype=x_dtype, log10_lo=-1.5, log10_hi=0.5 + ) + + rmse_pt, max_pt, fro_pt = _err_metrics(x, _recon_per_tensor(x)) + rmse_1x64, max_1x64, fro_1x64 = _err_metrics(x, _recon_1x64(x)) + + with capsys.disabled(): + print( + f"\n[dyn_range_modest {M}x{N} {x_dtype} seed={seed}]" + f" rmse: pt={rmse_pt:.4e} 1x64={rmse_1x64:.4e}" + f" ratio={rmse_1x64 / max(rmse_pt, 1e-30):.3f} |" + f" max_abs: pt={max_pt:.4e} 1x64={max_1x64:.4e} |" + f" frob_rel: pt={fro_pt:.4e} 1x64={fro_1x64:.4e}" + ) + + assert rmse_1x64 <= rmse_pt * 1.05, ( + f"1x64 unexpectedly worse than per-tensor on modest dynamic-range " + f"input: rmse_1x64={rmse_1x64:.4e} > 1.05 * rmse_pt={rmse_pt:.4e}" + ) From f80c36b35d04a92848b79423a97fa7fa6105eb36 Mon Sep 17 00:00:00 2001 From: Cael Ling Date: Mon, 27 Apr 2026 04:34:24 -0700 Subject: [PATCH 04/10] Simplify accuracy advantage test assertion Signed-off-by: Cael Ling --- .../pytorch/nvfp4/test_nvfp4_1x64_accuracy.py | 122 ++++++++++++++---- 1 file changed, 97 insertions(+), 25 deletions(-) diff --git a/tests/pytorch/nvfp4/test_nvfp4_1x64_accuracy.py b/tests/pytorch/nvfp4/test_nvfp4_1x64_accuracy.py index d437f9c365..79251ba089 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_1x64_accuracy.py +++ b/tests/pytorch/nvfp4/test_nvfp4_1x64_accuracy.py @@ -172,14 +172,17 @@ def _gen_per_window_dynamic_range( ) -> torch.Tensor: """Each 1x64 window has its own log-uniform magnitude scale. - This is the scenario the hierarchical scheme is built for: per-window - ``S_enc`` adapts to local magnitude while the per-tensor ``S_enc`` is - pinned to the loudest window and crushes precision in the quiet ones - (or, in the extreme, rounds the quiet windows' E4M3 ``s_dec`` to zero). - - ``log10_lo`` / ``log10_hi`` control the dynamic range: - * ``[-1.5, 0.5]`` -- modest (~30x ratio), small-but-real advantage. - * ``[-5.0, 0.5]`` -- extreme (~3e5 ratio), large advantage. + This generator is used to verify that 1x64 does **not regress** versus + per-tensor when per-window magnitude varies but stays inside E4M3's + representable range. It deliberately does not try to demonstrate a + large advantage: per-block ``s_dec`` already adapts to ``vec_max`` in + both schemes, so as long as ``s_dec`` does not underflow E4M3, FP4 + resolution per element is ``vec_max / 12`` regardless of the ``S_enc`` + choice, and the two schemes tie up to E4M3 rounding noise. + + To actually demonstrate the 1x64 advantage in absolute RMSE terms, see + ``_gen_sparse_extreme_outliers`` -- that distribution forces prod's + ``s_dec`` into the underflow regime for the bulk of the tensor. """ g = torch.Generator(device=device).manual_seed(seed) n_win = (N + WINDOW_K - 1) // WINDOW_K @@ -193,6 +196,61 @@ def _gen_per_window_dynamic_range( return x.to(dtype) +def _gen_sparse_extreme_outliers( + M: int, + N: int, + *, + seed: int, + device: str, + dtype: torch.dtype, + outlier_mag: float = 1.0e6, + n_outlier_windows: int = 4, +) -> torch.Tensor: + """``N(0, 1)`` background plus a handful of extreme outlier elements. + + This is the distribution where 1x64's advantage is sharp in absolute + RMSE. The mechanism: + + * The per-tensor scheme derives ``S_enc_global`` from the global amax, + which the outliers drag up to ~``outlier_mag``. With + ``outlier_mag = 1e6``, ``S_enc_global ~ 2.7e-3``; for a non-outlier + 1x16 block (``vec_max ~ 2``) the resulting ``s_dec_fp32`` is + ``~9e-4``, **below E4M3's smallest subnormal (~2e-3)**. ``s_dec`` + rounds to zero, so reconstruction of every non-outlier block is + identically zero -- catastrophic loss for the bulk of the tensor. + + * The hierarchical scheme uses ``S_enc_tile`` per 1x64 window. In the + ~``M*n_win - n_outlier_windows`` non-outlier windows ``tile_amax`` + sees only ``N(0, 1)``, so ``S_enc_tile ~ 900``, ``s_dec`` stays well + inside E4M3, and FP4 quantization runs at standard precision. + Loss is restricted to the (small) bulk inside outlier-containing + windows. + + The expected RMSE ratio ``rmse_1x64 / rmse_pt`` is ``~ 0.06`` -- well + below the 0.5 threshold the test asserts. + """ + g = torch.Generator(device=device).manual_seed(seed) + x = torch.randn((M, N), generator=g, device=device, dtype=torch.float32) + + n_win = N // WINDOW_K + if n_win == 0: + # Tensor is narrower than one full 1x64 window; the scheme degenerates + # and there is no meaningful advantage to test. + return x.to(dtype) + + total_windows = M * n_win + n_outlier_windows = min(n_outlier_windows, total_windows) + flat_idx = torch.randperm(total_windows, generator=g, device=device)[:n_outlier_windows] + outlier_rows = flat_idx // n_win + outlier_cols = (flat_idx % n_win) * WINDOW_K + signs = torch.randint( + 0, 2, (n_outlier_windows,), generator=g, device=device, dtype=torch.int32 + ).to(torch.float32) * 2 - 1 + x[outlier_rows, outlier_cols] = signs * float(outlier_mag) + + return x.to(dtype) + + _NEEDS_CUDA = pytest.mark.skipif( not torch.cuda.is_available(), reason="accuracy comparison runs on CUDA to mirror the rest of the nvfp4 suite", @@ -239,39 +297,53 @@ def test_1x64_at_least_as_good_as_per_tensor_on_gaussian( @pytest.mark.parametrize("M, N", [(256, 1024), (512, 2048)]) @pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) @pytest.mark.parametrize("seed", [0, 1, 2]) -def test_1x64_strictly_better_on_extreme_per_window_dynamic_range( +def test_1x64_better_than_per_tensor_on_sparse_extreme_outliers( M: int, N: int, x_dtype: torch.dtype, seed: int, capsys ) -> None: - """When per-window magnitude spans ~5 orders of magnitude, 1x64 must win. - - With ``log10_lo = -5`` the loudest-to-quietest window scale ratio is - ~3e5. The per-tensor ``S_enc`` is pinned to the loudest window, which - drives ``s_dec`` for quiet windows toward (or into) E4M3 underflow -- - catastrophic for those positions. The hierarchical scheme assigns each - window its own ``S_enc_tile`` and recovers near-full precision in the - quiet ones, so the overall RMSE drops by at least 2x. + """Sparse extreme outliers force prod's ``s_dec`` into E4M3 underflow. + + Per-block ``s_dec`` already adapts to local ``vec_max`` in both + schemes, so within E4M3's representable range FP4 resolution per + element is ``~ vec_max / 12`` independent of the ``S_enc`` choice -- + the schemes tie. The 1x64 advantage in *absolute* RMSE terms shows up + only when prod's E4M3 ``s_dec`` underflows, and only when those + underflowed positions also dominate ``mean(x^2)``. + + To put both conditions in effect simultaneously, this test uses an + ``N(0, 1)`` bulk plus a handful of outliers of magnitude ``1e6``: + + * ``S_enc_global = (FP8_MAX * FP4_MAX) / amax ~ 2.7e-3`` -- + ``s_dec`` for non-outlier blocks underflows; reconstruction of the + entire bulk collapses to zero. ``rmse_pt ~ 1.0``. + + * ``S_enc_tile`` for ~99% of windows is set by ``tile_amax ~ 3`` + (no outlier present); FP4 quantization runs at standard precision. + ``rmse_1x64 ~ 0.06``. + + The assertion only requires 1x64 to be strictly better than + per-tensor; the comparison values are printed for each parametrized + case so the actual margin (empirically ``rmse_1x64 / rmse_pt`` is + around ``0.05 - 0.10``) is visible in the test log. """ device = "cuda" - x = _gen_per_window_dynamic_range( - M, N, seed=seed, device=device, dtype=x_dtype, log10_lo=-5.0, log10_hi=0.5 - ) + x = _gen_sparse_extreme_outliers(M, N, seed=seed, device=device, dtype=x_dtype) rmse_pt, max_pt, fro_pt = _err_metrics(x, _recon_per_tensor(x)) rmse_1x64, max_1x64, fro_1x64 = _err_metrics(x, _recon_1x64(x)) with capsys.disabled(): print( - f"\n[dyn_range_extreme {M}x{N} {x_dtype} seed={seed}]" + f"\n[sparse_outliers {M}x{N} {x_dtype} seed={seed}]" f" rmse: pt={rmse_pt:.4e} 1x64={rmse_1x64:.4e}" f" ratio={rmse_1x64 / max(rmse_pt, 1e-30):.3f} |" f" max_abs: pt={max_pt:.4e} 1x64={max_1x64:.4e} |" f" frob_rel: pt={fro_pt:.4e} 1x64={fro_1x64:.4e}" ) - assert rmse_1x64 < rmse_pt * 0.5, ( - f"1x64 failed to outperform per-tensor on extreme dynamic-range input " - f"by the expected margin: rmse_1x64={rmse_1x64:.4e} vs" - f" 0.5 * rmse_pt={0.5 * rmse_pt:.4e}" + assert rmse_1x64 < rmse_pt, ( + f"1x64 was not strictly better than per-tensor on " + f"sparse-extreme-outlier input: " + f"rmse_1x64={rmse_1x64:.4e} >= rmse_pt={rmse_pt:.4e}" ) From 72180d98caa3a9e2d8aa9999ab8fbef83fccd050 Mon Sep 17 00:00:00 2001 From: Cael Ling Date: Mon, 27 Apr 2026 06:01:56 -0700 Subject: [PATCH 05/10] Fix NVFP4 rowwise 1x64 kernel zero-block NaN path Signed-off-by: Cael Ling --- .../common/cast/nvfp4/quantize_nvfp4_1x64_rowwise.cu | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/transformer_engine/common/cast/nvfp4/quantize_nvfp4_1x64_rowwise.cu b/transformer_engine/common/cast/nvfp4/quantize_nvfp4_1x64_rowwise.cu index 72dde17e7a..70985a4965 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_nvfp4_1x64_rowwise.cu +++ b/transformer_engine/common/cast/nvfp4/quantize_nvfp4_1x64_rowwise.cu @@ -82,7 +82,14 @@ __global__ void __launch_bounds__(64) nvfp4_rowwise_1x64_per_tile( bmx = fmaxf(bmx, fabsf(v)); } const fp8e4m3 s_dec = compute_decoding_scaling_factor(bmx, S_enc_tile); - const float block_scale = __fdiv_rn(S_enc_tile, static_cast(s_dec)); + // Match the reference's all-zero-block branch: when ``bmx == 0`` the + // FP8 cast saturates ``s_dec`` to 0, so the naive ``S_enc / s_dec`` + // would yield ``+inf`` and the subsequent ``0 * +inf`` a NaN that + // ``cvt.rn.satfinite.e2m1x4.f32`` resolves to ``+FP4_MAX`` (0x7) on + // SM10, not 0x0. Short-circuiting here keeps the kernel byte-exact + // with ``NVFP4Quantizer1x64Ref``. + const float s_dec_f = static_cast(s_dec); + const float block_scale = (s_dec_f == 0.f) ? 0.f : __fdiv_rn(S_enc_tile, s_dec_f); const int c16 = cs / 16; row_scales[static_cast(r) * scale_stride + static_cast(c16)] = s_dec; From 92350db1701e5ca96e80ffb33041e06e6241e0fb Mon Sep 17 00:00:00 2001 From: Cael Ling Date: Tue, 28 Apr 2026 02:46:28 -0700 Subject: [PATCH 06/10] Extend NVFP4 1x64 cast to fused rowwise + columnwise output Signed-off-by: Cael Ling --- .../nvfp4/test_nvfp4_1x64_quantize_exact.py | 273 +++++++++------- transformer_engine/common/CMakeLists.txt | 2 +- .../common/cast/dispatch/quantize.cuh | 18 +- .../common/cast/nvfp4/quantize_nvfp4_1x64.cu | 297 ++++++++++++++++++ .../common/cast/nvfp4/quantize_nvfp4_1x64.cuh | 45 +++ .../cast/nvfp4/quantize_nvfp4_1x64_rowwise.cu | 168 ---------- .../nvfp4/quantize_nvfp4_1x64_rowwise.cuh | 31 -- transformer_engine/pytorch/csrc/nvfp4_1x64.h | 30 +- transformer_engine/pytorch/csrc/quantizer.cpp | 52 ++- .../custom_recipes/quantization_nvfp4_1x64.py | 186 ++++++----- 10 files changed, 671 insertions(+), 431 deletions(-) create mode 100644 transformer_engine/common/cast/nvfp4/quantize_nvfp4_1x64.cu create mode 100644 transformer_engine/common/cast/nvfp4/quantize_nvfp4_1x64.cuh delete mode 100644 transformer_engine/common/cast/nvfp4/quantize_nvfp4_1x64_rowwise.cu delete mode 100644 transformer_engine/common/cast/nvfp4/quantize_nvfp4_1x64_rowwise.cuh diff --git a/tests/pytorch/nvfp4/test_nvfp4_1x64_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_1x64_quantize_exact.py index 694382970f..324f9abd4e 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_1x64_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_1x64_quantize_exact.py @@ -2,13 +2,23 @@ # # See LICENSE for license information. -"""Bit-exact tests for the NVFP4 rowwise 1x64 local-encode CUDA kernel. +"""Bit-exact tests for the NVFP4 hierarchical 1x64 cast CUDA kernel. Methodology mirrors ``test_nvfp4_quantize_exact.py``: invoke the kernel through ``NVFP4Quantizer`` (with the ``NVTE_NVFP4_ROWWISE_1X64_LOCAL_ENCODE=1`` -env flag enabling the alternate dispatch), and compare the resulting -``(qx, sx, amax)`` triple against the pure-PyTorch oracle -``NVFP4Quantizer1x64Ref`` byte-for-byte (``atol=rtol=0``). +env flag enabling the alternate dispatch), and compare every output buffer +against a pure-PyTorch oracle (``NVFP4Quantizer1x64Ref``) byte-for-byte +(``atol=rtol=0``). + +The kernel writes up to six tensors per call: + +* rowwise FP4 data, rowwise per-1x16 E4M3 ``s_dec``, rowwise per-1x64 + window amax; +* columnwise (transposed) FP4 data, columnwise per-1x16 E4M3 ``s_dec``, + columnwise per-1x64 window amax. + +This file covers the rowwise-only, columnwise-only, and rowwise+columnwise +configurations -- the latter being the production-equivalent fused mode. """ import pytest @@ -26,169 +36,200 @@ def unpack_fp4(x: torch.Tensor) -> torch.Tensor: - """Unpack two FP4 values per byte into one ``uint8`` value per element. - - Identical to the helper in ``test_nvfp4_quantize_exact.py`` -- duplicated - here so the two test suites stay independent. - """ + """Unpack two FP4 values per byte into one ``uint8`` value per element.""" repeated = x.repeat_interleave(2, dim=1) repeated[:, 0::2] &= 0x0F repeated[:, 1::2] >>= 4 return repeated -def _check_quantization_1x64_versus_reference( - x_dtype: torch.dtype, - M: int, - N: int, +def _check_quantization_1x64_versus_reference_with_input( + x: torch.Tensor, + *, + rowwise: bool, + columnwise: bool, ) -> None: - """Quantize ``(M, N)`` random input through both kernel and reference and - assert the rowwise data, scale, and global amax all match bit-exactly.""" + """Quantize ``x`` through both kernel and reference and assert that every + requested output (data, scale, window-amax, on each requested direction) + matches bit-exactly. + """ te_dtype = tex.DType.kFloat4E2M1 - device = "cuda" - - seed = 0 - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - x = torch.randn((M, N), dtype=x_dtype, device=device) - # Kernel path. ``with_rht=False`` and ``with_2d_quantization=False`` are - # required by the 1x64 dispatch's preconditions; ``columnwise=False`` - # because the kernel does not produce a transposed output. quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, - rowwise=True, - columnwise=False, + rowwise=rowwise, + columnwise=columnwise, with_amax_reduction=False, amax_reduction_group=None, with_rht=False, with_post_rht_amax=False, with_2d_quantization=False, ) - x_nvfp4_sut = quantizer(x) + sut = quantizer(x) + ref = NVFP4Quantizer1x64Ref(rowwise=rowwise, columnwise=columnwise).quantize(x) + + if rowwise: + qx_sut = unpack_fp4(sut._rowwise_data.view(dtype=torch.uint8)) + qx_ref = unpack_fp4(ref.data.view(dtype=torch.uint8)) + sx_sut = sut._rowwise_scale_inv.view(dtype=torch.uint8) + sx_ref = ref.scale.view(dtype=torch.uint8) + # The kernel may pad qx/sx to alignment boundaries; only the + # un-padded prefix is meaningfully written. + torch.testing.assert_close( + qx_sut[: qx_ref.shape[0], : qx_ref.shape[1]], qx_ref, atol=0.0, rtol=0.0 + ) + torch.testing.assert_close( + sx_sut[: sx_ref.shape[0], : sx_ref.shape[1]], sx_ref, atol=0.0, rtol=0.0 + ) + torch.testing.assert_close(sut._amax_rowwise, ref.window_amax_row, atol=0.0, rtol=0.0) + + if columnwise: + qxt_sut = unpack_fp4(sut._columnwise_data.view(dtype=torch.uint8)) + qxt_ref = unpack_fp4(ref.columnwise_data.view(dtype=torch.uint8)) + sxt_sut = sut._columnwise_scale_inv.view(dtype=torch.uint8) + sxt_ref = ref.columnwise_scale.view(dtype=torch.uint8) + torch.testing.assert_close( + qxt_sut[: qxt_ref.shape[0], : qxt_ref.shape[1]], qxt_ref, atol=0.0, rtol=0.0 + ) + torch.testing.assert_close( + sxt_sut[: sxt_ref.shape[0], : sxt_ref.shape[1]], sxt_ref, atol=0.0, rtol=0.0 + ) + torch.testing.assert_close( + sut._amax_columnwise, ref.window_amax_col, atol=0.0, rtol=0.0 + ) + + +def _check_random( + x_dtype: torch.dtype, + M: int, + N: int, + *, + rowwise: bool, + columnwise: bool, +) -> None: + """Random-input variant. Seeds are fixed so failures reproduce.""" + device = "cuda" + seed = 0 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + x = torch.randn((M, N), dtype=x_dtype, device=device) + _check_quantization_1x64_versus_reference_with_input( + x, rowwise=rowwise, columnwise=columnwise + ) - qx = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8) - sx = x_nvfp4_sut._rowwise_scale_inv - qx_amax = x_nvfp4_sut._amax_rowwise - # Reference path. - ref = NVFP4Quantizer1x64Ref().quantize(x) - qx_ref = unpack_fp4(ref.data.view(dtype=torch.uint8)) - sx_ref = ref.scale.view(dtype=torch.uint8) - ref_amax = ref.global_amax_row +# Shapes where both M and N are multiples of 64 -- the 1x64 hierarchy's +# strict alignment requirement (enforced by the dispatcher). +_SHAPES_64x64_MULTIPLE = [ + (64, 64), + (128, 128), + (256, 256), + (256, 512), + (1024, 256), + (256, 1024), + (2048, 2048), + (1024, 2048), + # Non-square shapes that exercise distinct row/col tile counts -- the + # columnwise pass uses M/64 windows, the rowwise uses N/64. + (64, 256), + (256, 64), + (192, 384), + (384, 192), +] - qx = unpack_fp4(qx) - # The kernel may pad qx/sx to alignment boundaries; only the unpadded - # prefix is meaningfully written. Slice both sides down to the reference - # shape before comparing (the reference returns un-padded tensors). - qx_valid = qx[: qx_ref.shape[0], : qx_ref.shape[1]] - sx_valid = sx[: sx_ref.shape[0], : sx_ref.shape[1]] +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize("M, N", _SHAPES_64x64_MULTIPLE) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +def test_nvfp4_1x64_quantize_rowwise( + monkeypatch, x_dtype: torch.dtype, M: int, N: int +) -> None: + """Rowwise-only configuration -- preserves the original PR's coverage.""" + monkeypatch.setenv("NVTE_NVFP4_ROWWISE_1X64_LOCAL_ENCODE", "1") + monkeypatch.setenv("NVTE_NVFP4_DISABLE_RHT", "1") + _check_random(x_dtype, M, N, rowwise=True, columnwise=False) + - torch.testing.assert_close(qx_valid, qx_ref, atol=0.0, rtol=0.0) - torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0) - torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0) +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize("M, N", _SHAPES_64x64_MULTIPLE) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +def test_nvfp4_1x64_quantize_columnwise( + monkeypatch, x_dtype: torch.dtype, M: int, N: int +) -> None: + """Columnwise-only -- exercises the transposed output path on its own.""" + monkeypatch.setenv("NVTE_NVFP4_ROWWISE_1X64_LOCAL_ENCODE", "1") + monkeypatch.setenv("NVTE_NVFP4_DISABLE_RHT", "1") + _check_random(x_dtype, M, N, rowwise=False, columnwise=True) @pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) -@pytest.mark.parametrize( - "M, N", - [ - # K is a multiple of WINDOW_K=64 - (128, 128), - (256, 256), - (256, 512), - (1024, 256), - (256, 1024), - (2048, 2048), - (1024, 2048), - # K is a multiple of BLOCK_K=16 but not of WINDOW_K -- exercises the - # partial-last-window path in the kernel. - (256, 80), - (256, 272), - (256, 336), - ], -) +@pytest.mark.parametrize("M, N", _SHAPES_64x64_MULTIPLE) @pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) -def test_nvfp4_1x64_quantize_versus_reference( - monkeypatch, - x_dtype: torch.dtype, - M: int, - N: int, +def test_nvfp4_1x64_quantize_rowwise_columnwise( + monkeypatch, x_dtype: torch.dtype, M: int, N: int ) -> None: - """Random-input bit-exact test across a representative shape grid.""" + """Fused rowwise+columnwise -- the production-equivalent configuration.""" monkeypatch.setenv("NVTE_NVFP4_ROWWISE_1X64_LOCAL_ENCODE", "1") monkeypatch.setenv("NVTE_NVFP4_DISABLE_RHT", "1") - _check_quantization_1x64_versus_reference(x_dtype, M, N) + _check_random(x_dtype, M, N, rowwise=True, columnwise=True) @pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) @pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) -def test_nvfp4_1x64_quantize_extrema(monkeypatch, x_dtype: torch.dtype) -> None: - """Stress the saturating-cast and ``S_enc`` clamps with extreme inputs. +@pytest.mark.parametrize( + "rowwise, columnwise", + [(True, False), (False, True), (True, True)], + ids=["row", "col", "rowcol"], +) +def test_nvfp4_1x64_quantize_extrema( + monkeypatch, x_dtype: torch.dtype, rowwise: bool, columnwise: bool +) -> None: + """Stress the saturating cast and ``S_enc`` clamps on each direction. Inputs: - * an outlier-heavy row (one ~FP4_MAX value per 1x64 window plus tiny - noise) -- exercises the ``s_dec`` saturation branch; + * an outlier-heavy row (one ~FP4_MAX value per 1x64 K-window plus tiny + noise) -- exercises the rowwise ``s_dec`` saturation branch; + * an outlier-heavy column (one ~FP4_MAX value per 1x64 M-window plus + tiny noise) -- the columnwise mirror of the above; * an all-zero region -- exercises the ``tile_amax == 0`` fallback that - promotes ``S_enc`` to 1.0; - * a uniform constant row -- ``s_dec`` should be the same E4M3 byte for - every block in the row. + promotes ``S_enc`` to 1.0 and the kernel's ``s_dec == 0`` + short-circuit (without which ``cvt.rn.satfinite.e2m1x4.f32(NaN)`` + on SM10 would saturate to ``+FP4_MAX = 0x7`` instead of ``0x0``); + * a uniform constant block -- ``s_dec`` should be the same E4M3 byte + for every block in the affected window. """ monkeypatch.setenv("NVTE_NVFP4_ROWWISE_1X64_LOCAL_ENCODE", "1") monkeypatch.setenv("NVTE_NVFP4_DISABLE_RHT", "1") device = "cuda" - M, N = 64, 256 + M, N = 128, 256 torch.manual_seed(0) torch.cuda.manual_seed(0) x = torch.randn((M, N), dtype=x_dtype, device=device) * 0.01 - # Row 0: per-window outliers. + # Row 0: per-1x64-K-window outlier on the rowwise side. for w in range(N // 64): x[0, w * 64] = 5.5 - # Row 1: all zero (degenerate window amax path). - x[1] = 0.0 - - # Row 2: uniform constant. - x[2] = 0.75 + # Col 0: per-1x64-M-window outlier on the columnwise side. Distinct row + # from row 0 so the two outlier patterns do not overlap. + for w in range(M // 64): + x[w * 64 + 1, 0] = 5.25 # row index != 0, col 0 - _check_quantization_1x64_versus_reference_with_input(x) + # Row 1: all zero (degenerate window amax path on the rowwise side). + # We've already touched (1, 0) above so re-zero just (1, 1:) -- the + # block at (1, 0..15) still contains a single non-zero outlier 5.25, + # which is fine; the all-zero stress remains in (1, 16:). + x[1, 16:] = 0.0 + # Col 1: all zero (degenerate window amax path on the columnwise side). + x[16:, 1] = 0.0 -def _check_quantization_1x64_versus_reference_with_input(x: torch.Tensor) -> None: - """Variant of the random-input checker that consumes a caller-provided - tensor (used by the extrema test).""" - te_dtype = tex.DType.kFloat4E2M1 + # Row 2: uniform constant -- exercises both directions' constant-block + # path simultaneously for that row's columns it spans. + x[2] = 0.75 - quantizer = NVFP4Quantizer( - fp4_dtype=te_dtype, - rowwise=True, - columnwise=False, - with_amax_reduction=False, - amax_reduction_group=None, - with_rht=False, - with_post_rht_amax=False, - with_2d_quantization=False, + _check_quantization_1x64_versus_reference_with_input( + x, rowwise=rowwise, columnwise=columnwise ) - x_nvfp4_sut = quantizer(x) - - qx = x_nvfp4_sut._rowwise_data.view(dtype=torch.uint8) - sx = x_nvfp4_sut._rowwise_scale_inv - qx_amax = x_nvfp4_sut._amax_rowwise - - ref = NVFP4Quantizer1x64Ref().quantize(x) - qx_ref = unpack_fp4(ref.data.view(dtype=torch.uint8)) - sx_ref = ref.scale.view(dtype=torch.uint8) - ref_amax = ref.global_amax_row - - qx = unpack_fp4(qx) - - qx_valid = qx[: qx_ref.shape[0], : qx_ref.shape[1]] - sx_valid = sx[: sx_ref.shape[0], : sx_ref.shape[1]] - - torch.testing.assert_close(qx_valid, qx_ref, atol=0.0, rtol=0.0) - torch.testing.assert_close(sx_valid, sx_ref, atol=0.0, rtol=0.0) - torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index bf17d5556f..034c113f70 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -219,7 +219,7 @@ list(APPEND transformer_engine_cuda_arch_specific_sources multi_tensor/compute_scale.cu recipe/mxfp8_scaling.cu recipe/nvfp4.cu - cast/nvfp4/quantize_nvfp4_1x64_rowwise.cu + cast/nvfp4/quantize_nvfp4_1x64.cu transpose/quantize_transpose_square_blockwise.cu transpose/quantize_transpose_vector_blockwise_fp4.cu) diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index dccf5a2091..969e0b09ef 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -22,7 +22,7 @@ #include "../mxfp8/quantize_mxfp8.cuh" #include "../nvfp4/group_quantize_transpose_nvfp4.cuh" #include "../nvfp4/quantize_nvfp4.cuh" -#include "../nvfp4/quantize_nvfp4_1x64_rowwise.cuh" +#include "../nvfp4/quantize_nvfp4_1x64.cuh" #include "../nvfp4/quantize_transpose_nvfp4.cuh" namespace transformer_engine { @@ -105,16 +105,18 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && (cols % 32 == 0) && output_tensor->has_data(); - // Per-1x64-K-tile S_enc (non-RHT rowwise only). Fused window kernel; no columnwise / no GEMM. + // Per-1x64-K-window S_enc (non-RHT). The kernel produces rowwise and/or + // columnwise outputs in a single fused tile pass; per-window amax is + // written into the existing ``amax`` / ``columnwise_amax`` slots, which + // the PyTorch wrapper allocates with shape (M, N/64) / (N, M/64) when + // ``nvfp4_1x64_local_encode`` is set. if (quant_config_cpp.nvfp4_rowwise_1x64_local_encode) { - NVTE_CHECK(!output_tensor->has_columnwise_data(), - "NVFP4 rowwise 1x64 local encode does not support columnwise (transposed) output."); NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, - "NVFP4 rowwise 1x64 local encode is incompatible with 2D block scaling."); + "NVFP4 1x64 local encode is incompatible with 2D block scaling."); NVTE_CHECK(!quant_config_cpp.stochastic_rounding, - "NVFP4 rowwise 1x64 local encode does not support stochastic rounding yet."); - nvfp4::quantize_rowwise_1x64_local_encode(*input_tensor, *noop_tensor, output_tensor, - quant_config_cpp, stream); + "NVFP4 1x64 local encode does not support stochastic rounding yet."); + nvfp4::quantize_1x64_local_encode(*input_tensor, *noop_tensor, output_tensor, + quant_config_cpp, stream); } else if (use_optimized_kernel) { if (quant_config_cpp.nvfp4_2d_quantization) { nvfp4::quantize_transpose( diff --git a/transformer_engine/common/cast/nvfp4/quantize_nvfp4_1x64.cu b/transformer_engine/common/cast/nvfp4/quantize_nvfp4_1x64.cu new file mode 100644 index 0000000000..becb07a2ea --- /dev/null +++ b/transformer_engine/common/cast/nvfp4/quantize_nvfp4_1x64.cu @@ -0,0 +1,297 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "common/common.h" +#include "common/cast/nvfp4/quantize_nvfp4_1x64.cuh" +#include "common/cast/nvfp4/core_nvfp4.cuh" +#include "common/util/ptx.cuh" +#include "common/utils.cuh" + +namespace transformer_engine { +namespace dispatch { +namespace nvfp4 { +namespace { + +#if FP4_TYPE_SUPPORTED + +using ptx::FPx2; +using quantization_SF::compute_decoding_scaling_factor; +using core::compute_global_encode_scaling_factor_FP4; + +// One CUDA block = one 64x64 input tile in (M, N) row-major space. +// +// Grid layout: (blockIdx.x, blockIdx.y) = (n_window, m_tile), so each CTA +// owns exactly one 1x64 K-window for the rowwise pass *and* one 1x64 +// (transposed-K) M-window for the columnwise pass. With 64 threads per +// CTA, threadIdx.x doubles as "row index in tile" during the rowwise pass +// and "column index in tile" during the columnwise pass. +// +// Either pass can be skipped at runtime by passing nullptr for its three +// output buffers; the SMEM tile load is shared. +// +// SMEM is padded to ``[64][65]`` so the columnwise transpose access +// (``in_sm[e][tid]`` walking down a column) does not fall on the same +// 32-bank lane for every row. +template +__global__ void __launch_bounds__(64) nvfp4_1x64_fused_per_tile( + const IType* __restrict__ in, const size_t rows, const size_t cols, const int ld_row_elts, + // Rowwise outputs (all three are non-null together, or all null). + uint8_t* __restrict__ q_row, fp8e4m3* __restrict__ s_dec_row, + float* __restrict__ w_amax_row, const size_t s_dec_row_stride, + const size_t w_amax_row_stride, + // Columnwise (transposed) outputs (all three together, or all null). + uint8_t* __restrict__ q_col, fp8e4m3* __restrict__ s_dec_col, + float* __restrict__ w_amax_col, const size_t s_dec_col_stride, + const size_t w_amax_col_stride, + const float* __restrict__ noop) { +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + + const int tile_n = static_cast(blockIdx.x); + const int tile_m = static_cast(blockIdx.y); + const int tid = static_cast(threadIdx.x); // 0..63 + + const int row_base = tile_m * 64; + const int col_base = tile_n * 64; + if (row_base >= static_cast(rows) || col_base >= static_cast(cols)) { + return; + } + + const bool do_row = (q_row != nullptr); + const bool do_col = (q_col != nullptr); + + // 64x64 fp32 staging buffer with +1 column padding to side-step bank + // conflicts on the columnwise transpose access pattern. + __shared__ float in_sm[64][65]; + + // Cooperative load: thread tid loads its assigned row (64 elements). + // Padding with zeros keeps the amax reductions correct on M-tail rows + // (the dispatcher already guarantees full 64-aligned tiles, so this is + // strictly defensive). + { + const int gr = row_base + tid; + if (gr < static_cast(rows)) { +#pragma unroll + for (int e = 0; e < 64; e++) { + const int gc = col_base + e; + in_sm[tid][e] = + static_cast(in[static_cast(gr) * ld_row_elts + gc]); + } + } else { +#pragma unroll + for (int e = 0; e < 64; e++) { + in_sm[tid][e] = 0.f; + } + } + } + __syncthreads(); + + using IType2 = FPx2; + + // ============================ ROWWISE PASS ============================ + if (do_row && (row_base + tid) < static_cast(rows)) { + const int r = row_base + tid; + + float wmx = 0.f; +#pragma unroll + for (int e = 0; e < 64; e++) { + wmx = fmaxf(wmx, fabsf(in_sm[tid][e])); + } + const float S_enc = compute_global_encode_scaling_factor_FP4(fmaxf(wmx, 1e-12f)); + + w_amax_row[static_cast(r) * w_amax_row_stride + tile_n] = wmx; + + uint8_t* row_out = q_row + static_cast(r) * (cols / 2); +#pragma unroll + for (int b = 0; b < 4; b++) { + float bmx = 0.f; + float vals[16]; +#pragma unroll + for (int e = 0; e < 16; e++) { + const float v = in_sm[tid][b * 16 + e]; + vals[e] = v; + bmx = fmaxf(bmx, fabsf(v)); + } + const fp8e4m3 s_dec = compute_decoding_scaling_factor(bmx, S_enc); + const float s_dec_f = static_cast(s_dec); + // Match the reference's all-zero-block branch (see + // ``NVFP4Quantizer1x64Ref``): when ``bmx == 0`` ``s_dec`` saturates + // to 0 and a naive ``S_enc / 0`` would NaN through the cvt. + const float block_scale = (s_dec_f == 0.f) ? 0.f : __fdiv_rn(S_enc, s_dec_f); + + const int sub_blk_global = (col_base + b * 16) / 16; + s_dec_row[static_cast(r) * s_dec_row_stride + sub_blk_global] = s_dec; + + const size_t byte_off = static_cast(col_base + b * 16) / 2; +#pragma unroll + for (int q = 0; q < 4; q++) { + const int e0 = q * 4; + IType2 in01{static_cast(vals[e0]), static_cast(vals[e0 + 1])}; + IType2 in23{static_cast(vals[e0 + 2]), static_cast(vals[e0 + 3])}; + fp4e2m1x4 qu{}; + ptx::mul_cvt_4x(qu, in01, in23, block_scale); + *reinterpret_cast(row_out + byte_off + static_cast(2 * q)) = qu; + } + } + } + + // No barrier between rowwise and colwise passes: ``in_sm`` is read-only + // after the initial load+sync, all writes go to disjoint global memory + // regions, so the two passes are safely concurrent at warp granularity. + + // =========================== COLUMNWISE PASS =========================== + if (do_col && (col_base + tid) < static_cast(cols)) { + const int c = col_base + tid; + + float wmx = 0.f; +#pragma unroll + for (int e = 0; e < 64; e++) { + wmx = fmaxf(wmx, fabsf(in_sm[e][tid])); + } + const float S_enc = compute_global_encode_scaling_factor_FP4(fmaxf(wmx, 1e-12f)); + + w_amax_col[static_cast(c) * w_amax_col_stride + tile_m] = wmx; + + // Transposed output: q_col is laid out as (N, M/2). Each "row" of the + // transposed tensor corresponds to one column of the input, so byte + // stride is rows/2 along original-M. + uint8_t* col_out = q_col + static_cast(c) * (rows / 2); +#pragma unroll + for (int b = 0; b < 4; b++) { + float bmx = 0.f; + float vals[16]; +#pragma unroll + for (int e = 0; e < 16; e++) { + const float v = in_sm[b * 16 + e][tid]; + vals[e] = v; + bmx = fmaxf(bmx, fabsf(v)); + } + const fp8e4m3 s_dec = compute_decoding_scaling_factor(bmx, S_enc); + const float s_dec_f = static_cast(s_dec); + const float block_scale = (s_dec_f == 0.f) ? 0.f : __fdiv_rn(S_enc, s_dec_f); + + const int sub_blk_global = (row_base + b * 16) / 16; + s_dec_col[static_cast(c) * s_dec_col_stride + sub_blk_global] = s_dec; + + const size_t byte_off = static_cast(row_base + b * 16) / 2; +#pragma unroll + for (int q = 0; q < 4; q++) { + const int e0 = q * 4; + IType2 in01{static_cast(vals[e0]), static_cast(vals[e0 + 1])}; + IType2 in23{static_cast(vals[e0 + 2]), static_cast(vals[e0 + 3])}; + fp4e2m1x4 qu{}; + ptx::mul_cvt_4x(qu, in01, in23, block_scale); + *reinterpret_cast(col_out + byte_off + static_cast(2 * q)) = qu; + } + } + } +#endif // __CUDA_ARCH__ >= 1000 +} + +#endif // FP4_TYPE_SUPPORTED + +} // namespace + +void quantize_1x64_local_encode(const Tensor& input, const Tensor& noop, Tensor* output, + const QuantizationConfig& /* quant_config */, cudaStream_t stream) { +#if FP4_TYPE_SUPPORTED + CheckNoopTensor(noop, "cast_noop"); + NVTE_CHECK(input.has_data(), "NVFP4 1x64: input has no data."); + NVTE_CHECK(output->has_data() || output->has_columnwise_data(), + "NVFP4 1x64: at least one of rowwise/columnwise output must be allocated."); + NVTE_CHECK(!output->with_gemm_swizzled_scales, + "NVFP4 1x64: expects compact (non-gemm) scales on both directions."); + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + if (rows == 0 || cols == 0) { + return; + } + // 1x16 sub-block inside a 1x64 K-window: both dimensions must be multiples + // of 64 to keep all four sub-blocks of every window in-bounds for both + // directions. The PyTorch wrapper (NVFP4Quantizer::create_tensor) sizes the + // ``amax`` slot accordingly, so this check also pins down the Python side. + NVTE_CHECK(cols % 64 == 0, "NVFP4 1x64: K (cols) must be a multiple of 64, got: ", cols); + NVTE_CHECK(rows % 64 == 0, "NVFP4 1x64: M (rows) must be a multiple of 64, got: ", rows); + + uint8_t* q_row = nullptr; + fp8e4m3* s_dec_row = nullptr; + float* w_amax_row = nullptr; + size_t s_dec_row_stride = 0; + size_t w_amax_row_stride = 0; + if (output->has_data()) { + NVTE_CHECK(output->scale_inv.dptr != nullptr, + "NVFP4 1x64: rowwise scale_inv must be allocated."); + NVTE_CHECK(output->amax.dptr != nullptr, + "NVFP4 1x64: rowwise amax (per-window) buffer is required."); + q_row = reinterpret_cast(output->data.dptr); + s_dec_row = reinterpret_cast(output->scale_inv.dptr); + w_amax_row = reinterpret_cast(output->amax.dptr); + s_dec_row_stride = output->scale_inv.shape.size() > 1 ? output->scale_inv.shape[1] : 1; + w_amax_row_stride = output->amax.shape.size() > 1 ? output->amax.shape[1] : 1; + NVTE_CHECK(s_dec_row_stride == cols / 16, + "NVFP4 1x64: rowwise scale_inv stride must equal cols/16, got ", + s_dec_row_stride); + NVTE_CHECK(w_amax_row_stride == cols / 64, + "NVFP4 1x64: rowwise amax stride must equal cols/64, got ", w_amax_row_stride); + } + + uint8_t* q_col = nullptr; + fp8e4m3* s_dec_col = nullptr; + float* w_amax_col = nullptr; + size_t s_dec_col_stride = 0; + size_t w_amax_col_stride = 0; + if (output->has_columnwise_data()) { + NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, + "NVFP4 1x64: columnwise scale_inv must be allocated."); + NVTE_CHECK(output->columnwise_amax.dptr != nullptr, + "NVFP4 1x64: columnwise amax (per-window) buffer is required."); + q_col = reinterpret_cast(output->columnwise_data.dptr); + s_dec_col = reinterpret_cast(output->columnwise_scale_inv.dptr); + w_amax_col = reinterpret_cast(output->columnwise_amax.dptr); + s_dec_col_stride = output->columnwise_scale_inv.shape.size() > 1 + ? output->columnwise_scale_inv.shape[1] + : 1; + w_amax_col_stride = output->columnwise_amax.shape.size() > 1 + ? output->columnwise_amax.shape[1] + : 1; + NVTE_CHECK(s_dec_col_stride == rows / 16, + "NVFP4 1x64: columnwise scale_inv stride must equal rows/16, got ", + s_dec_col_stride); + NVTE_CHECK(w_amax_col_stride == rows / 64, + "NVFP4 1x64: columnwise amax stride must equal rows/64, got ", w_amax_col_stride); + } + + const size_t n_win = cols / 64; + const size_t m_tiles = rows / 64; + dim3 grid(static_cast(n_win), static_cast(m_tiles), 1); + constexpr int kBlock = 64; + + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + input.dtype(), IType, { + const IType* in_t = reinterpret_cast(input.data.dptr); + nvfp4_1x64_fused_per_tile<<>>( + in_t, rows, cols, static_cast(cols), q_row, s_dec_row, w_amax_row, + s_dec_row_stride, w_amax_row_stride, q_col, s_dec_col, w_amax_col, + s_dec_col_stride, w_amax_col_stride, + reinterpret_cast(noop.data.dptr)); + NVTE_CHECK_CUDA(cudaGetLastError()); + }); + +#else + (void)input; + (void)noop; + (void)output; + (void)stream; + NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); +#endif +} + +} // namespace nvfp4 +} // namespace dispatch +} // namespace transformer_engine diff --git a/transformer_engine/common/cast/nvfp4/quantize_nvfp4_1x64.cuh b/transformer_engine/common/cast/nvfp4/quantize_nvfp4_1x64.cuh new file mode 100644 index 0000000000..ddc2a165b8 --- /dev/null +++ b/transformer_engine/common/cast/nvfp4/quantize_nvfp4_1x64.cuh @@ -0,0 +1,45 @@ +/************************************************************************* + * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file quantize_nvfp4_1x64.cuh + * \brief NVFP4 hierarchical 1x64 cast (rowwise + optional transposed columnwise), + * with per-1x64-K-window S_enc and per-1x16 sub-block E4M3 s_dec. + * + * The kernel produces, for an (M, N) input, any non-empty subset of: + * * rowwise NVFP4 data + 1x16 E4M3 scales + (M, N/64) FP32 window amax + * * columnwise (transposed) NVFP4 data + 1x16 E4M3 scales + (N, M/64) FP32 window amax + * + * The "window amax" tensors are stored on the existing ``amax`` / + * ``columnwise_amax`` slots of the C++ ``Tensor`` (their shape is upgraded + * from a scalar (1,) -- as used by the per-tensor NVFP4 path -- to a 2D + * per-window buffer in 1x64 mode). Consumers that need the global tensor + * amax can take ``max`` over the per-window buffer at trivial cost. + * + * Non-RHT, non-2D, non-stochastic-rounding only. Both M and N are + * required to be multiples of 64 by the dispatcher. + */ +#ifndef TRANSFORMER_ENGINE_QUANTIZE_NVFP4_1X64_CUH_ +#define TRANSFORMER_ENGINE_QUANTIZE_NVFP4_1X64_CUH_ + +#include + +#include "../../common.h" + +namespace transformer_engine { +namespace dispatch { +namespace nvfp4 { + +// Hierarchical 1x64 + 1x16 NVFP4 cast. Routes to the fused rowwise+columnwise +// kernel; populates whichever of ``data`` / ``columnwise_data`` are present +// on ``output`` (and the matching scales + window amax buffers). +void quantize_1x64_local_encode(const Tensor& input, const Tensor& noop, Tensor* output, + const QuantizationConfig& quant_config, cudaStream_t stream); + +} // namespace nvfp4 +} // namespace dispatch +} // namespace transformer_engine + +#endif // TRANSFORMER_ENGINE_QUANTIZE_NVFP4_1X64_CUH_ diff --git a/transformer_engine/common/cast/nvfp4/quantize_nvfp4_1x64_rowwise.cu b/transformer_engine/common/cast/nvfp4/quantize_nvfp4_1x64_rowwise.cu deleted file mode 100644 index 70985a4965..0000000000 --- a/transformer_engine/common/cast/nvfp4/quantize_nvfp4_1x64_rowwise.cu +++ /dev/null @@ -1,168 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -#include "common/common.h" -#include "common/cast/nvfp4/quantize_nvfp4_1x64_rowwise.cuh" -#include "common/cast/nvfp4/core_nvfp4.cuh" -#include "common/util/ptx.cuh" -#include "common/utils.cuh" - -namespace transformer_engine { -namespace dispatch { -namespace nvfp4 { -namespace { - -#if FP4_TYPE_SUPPORTED - -using ptx::FPx2; -using quantization_SF::compute_decoding_scaling_factor; -using core::compute_global_encode_scaling_factor_FP4; - -// One CUDA block = one 1x64 K-tile in (row, K-window) layout. -// Threads reduce |x| over the tile, then S_enc = TE global formula on tile amax; -// 1x16 blocks share that S_enc. -template -__global__ void __launch_bounds__(64) nvfp4_rowwise_1x64_per_tile( - const IType* __restrict__ in, const size_t rows, const size_t cols, const int ld_row_elts, - uint8_t* __restrict__ out_data, // raw fp4 bytes (same layout as other NVFP4 rowwise) - fp8e4m3* __restrict__ row_scales, const size_t scale_stride, float* __restrict__ amax_global, - const float* __restrict__ noop) { -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - if (noop != nullptr && noop[0] == 1.0f) { - return; - } - - const int w = static_cast(blockIdx.x); - const int r = static_cast(blockIdx.y); - const int c0 = w * 64; - if (r >= static_cast(rows) || c0 >= static_cast(cols)) { - return; - } - - const int win_len = min(64, static_cast(cols) - c0); - - __shared__ float sm[64]; - for (int i = threadIdx.x; i < 64; i += blockDim.x) { - if (i < win_len) { - sm[i] = fabsf(static_cast(in[static_cast(r) * ld_row_elts + c0 + i])); - } else { - sm[i] = 0.f; - } - } - __syncthreads(); - - if (threadIdx.x == 0) { - float wmx = 0.f; - for (int i = 0; i < 64; i++) { - wmx = fmaxf(wmx, sm[i]); - } - sm[0] = wmx; - } - __syncthreads(); - - const float tile_amax = sm[0]; - const float S_enc_tile = compute_global_encode_scaling_factor_FP4(fmaxf(tile_amax, 1e-12f)); - - if (amax_global != nullptr) { - atomicMaxFloat(amax_global, tile_amax); - } - - if (threadIdx.x < 4) { - const int b = static_cast(threadIdx.x); - const int cs = c0 + b * 16; - if (b * 16 < win_len && cs + 16 <= static_cast(cols)) { - float bmx = 0.f; - float vals[16]; - for (int e = 0; e < 16; e++) { - const float v = static_cast(in[static_cast(r) * ld_row_elts + cs + e]); - vals[e] = v; - bmx = fmaxf(bmx, fabsf(v)); - } - const fp8e4m3 s_dec = compute_decoding_scaling_factor(bmx, S_enc_tile); - // Match the reference's all-zero-block branch: when ``bmx == 0`` the - // FP8 cast saturates ``s_dec`` to 0, so the naive ``S_enc / s_dec`` - // would yield ``+inf`` and the subsequent ``0 * +inf`` a NaN that - // ``cvt.rn.satfinite.e2m1x4.f32`` resolves to ``+FP4_MAX`` (0x7) on - // SM10, not 0x0. Short-circuiting here keeps the kernel byte-exact - // with ``NVFP4Quantizer1x64Ref``. - const float s_dec_f = static_cast(s_dec); - const float block_scale = (s_dec_f == 0.f) ? 0.f : __fdiv_rn(S_enc_tile, s_dec_f); - - const int c16 = cs / 16; - row_scales[static_cast(r) * scale_stride + static_cast(c16)] = s_dec; - - using IType2 = FPx2; - const size_t row_bytes = static_cast(cols) / 2; // fp4 packed: cols/2 bytes per row - uint8_t* row_out = out_data + static_cast(r) * row_bytes; - for (int q = 0; q < 4; q++) { - const int e0 = q * 4; - IType2 in01{static_cast(vals[e0]), static_cast(vals[e0 + 1])}; - IType2 in23{static_cast(vals[e0 + 2]), static_cast(vals[e0 + 3])}; - fp4e2m1x4 qu{}; - ptx::mul_cvt_4x(qu, in01, in23, block_scale); - *reinterpret_cast(row_out + static_cast(cs / 2) + static_cast(2 * q)) = - qu; - } - } - } -#endif -} - -#endif // FP4_TYPE_SUPPORTED - -} // namespace - -void quantize_rowwise_1x64_local_encode(const Tensor& input, const Tensor& noop, Tensor* output, - const QuantizationConfig& /* quant_config */, cudaStream_t stream) { -#if FP4_TYPE_SUPPORTED - CheckNoopTensor(noop, "cast_noop"); - NVTE_CHECK(input.has_data(), "NVFP4 1x64: input has no data."); - NVTE_CHECK(output->has_data(), "NVFP4 1x64: output has no data."); - NVTE_CHECK(!output->has_columnwise_data(), - "NVFP4 rowwise 1x64: columnwise (transpose) path is not supported (no RHT / no GEMM)."); - NVTE_CHECK(output->scale_inv.dptr != nullptr, "NVFP4 1x64: rowwise scale_inv must be allocated."); - NVTE_CHECK(!output->with_gemm_swizzled_scales, "NVFP4 1x64: expect compact (non-gemm) scales."); - NVTE_CHECK(output->amax.dptr != nullptr, "NVFP4 1x64: rowwise amax buffer is required."); - - const size_t rows = input.flat_first_dim(); - const size_t cols = input.flat_last_dim(); - if (rows == 0 || cols == 0) { - return; - } - NVTE_CHECK(cols % 16 == 0, "NVFP4 1x64: K must be a multiple of 16 (1x16 block size), got: ", - cols); - const size_t n_win = (cols + 63) / 64; - - uint8_t* out_ptr = reinterpret_cast(output->data.dptr); - fp8e4m3* scales = reinterpret_cast(output->scale_inv.dptr); - const size_t s_stride = output->scale_inv.shape.size() > 1 ? output->scale_inv.shape[1] : 1; - float* amax = reinterpret_cast(output->amax.dptr); - NVTE_CHECK_CUDA(cudaMemsetAsync(amax, 0, sizeof(float), stream)); - - dim3 grid(static_cast(n_win), static_cast(rows), 1); - constexpr int kBlock = 64; - - TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( - input.dtype(), IType, { - const IType* in_t = reinterpret_cast(input.data.dptr); - nvfp4_rowwise_1x64_per_tile<<>>( - in_t, rows, cols, static_cast(cols), out_ptr, scales, s_stride, amax, - reinterpret_cast(noop.data.dptr)); - NVTE_CHECK_CUDA(cudaGetLastError()); - }); - -#else - (void)input; - (void)noop; - (void)output; - (void)stream; - NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); -#endif -} - -} // namespace nvfp4 -} // namespace dispatch -} // namespace transformer_engine diff --git a/transformer_engine/common/cast/nvfp4/quantize_nvfp4_1x64_rowwise.cuh b/transformer_engine/common/cast/nvfp4/quantize_nvfp4_1x64_rowwise.cuh deleted file mode 100644 index 42a55dfcb7..0000000000 --- a/transformer_engine/common/cast/nvfp4/quantize_nvfp4_1x64_rowwise.cuh +++ /dev/null @@ -1,31 +0,0 @@ -/************************************************************************* - * Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - * - * See LICENSE for license information. - ************************************************************************/ - -/*! \file quantize_nvfp4_1x64_rowwise.cuh - * \brief NVFP4 rowwise cast with per-1x64-K-tile S_enc (non-RHT path; no columnwise / GEMM). - */ -#ifndef TRANSFORMER_ENGINE_QUANTIZE_NVFP4_1X64_ROWWISE_CUH_ -#define TRANSFORMER_ENGINE_QUANTIZE_NVFP4_1X64_ROWWISE_CUH_ - -#include - -#include "../../common.h" - -namespace transformer_engine { -namespace dispatch { -namespace nvfp4 { - -// Same TE NVFP4 math as quantize_transpose / vector_blockwise, but -// S_enc = (fp8_max*fp4_max)/max|x| over the current 1x64 K-tile in each row -// (per row, for each 64-stride K window). -void quantize_rowwise_1x64_local_encode(const Tensor& input, const Tensor& noop, Tensor* output, - const QuantizationConfig& quant_config, cudaStream_t stream); - -} // namespace nvfp4 -} // namespace dispatch -} // namespace transformer_engine - -#endif // TRANSFORMER_ENGINE_QUANTIZE_NVFP4_1X64_ROWWISE_CUH_ diff --git a/transformer_engine/pytorch/csrc/nvfp4_1x64.h b/transformer_engine/pytorch/csrc/nvfp4_1x64.h index 389241af2b..3e9372b812 100644 --- a/transformer_engine/pytorch/csrc/nvfp4_1x64.h +++ b/transformer_engine/pytorch/csrc/nvfp4_1x64.h @@ -5,8 +5,11 @@ ************************************************************************/ /*! \file nvfp4_1x64.h - * \brief Small helpers for NVFP4 per-1x64-K S_enc (env + config + preconditions), shared - * between NVFP4Quantizer and split_quantize to avoid duplicating policy. + * \brief Small helpers for NVFP4 hierarchical 1x64 cast (env + config + + * preconditions), shared between NVFP4Quantizer and split_quantize to + * avoid duplicating policy. The "rowwise" in the env-var name is + * historical -- the kernel now produces rowwise and/or columnwise + * (transposed) output in a single fused pass. */ #ifndef TRANSFORMER_ENGINE_PYTORCH_NVFP4_1X64_H_ #define TRANSFORMER_ENGINE_PYTORCH_NVFP4_1X64_H_ @@ -17,44 +20,37 @@ namespace transformer_engine::pytorch::nvfp4_1x64 { -/// Whether rowwise 1x64 local encode is requested (TE-wide env, same for single-tensor and split). +/// Whether the hierarchical 1x64 cast is requested. The env-var name retains +/// its original ROWWISE_ prefix for backward compatibility with users of the +/// rowwise-only kernel that shipped first; both directions are now supported. [[nodiscard]] inline bool local_encode_from_env() { return transformer_engine::getenv("NVTE_NVFP4_ROWWISE_1X64_LOCAL_ENCODE", false); } -/// Apply 2D mode, SR, and optional 1x64 flag to a quantization config (mirrors what NVFP4 needs). +/// Apply 2D mode, SR, and optional 1x64 flag to a quantization config. inline void config_apply(QuantizationConfigWrapper& cfg, bool nvfp4_2d, bool stochastic_rounding, - bool use_rowwise_1x64) { + bool use_1x64) { cfg.set_nvfp4_2d_quantization(nvfp4_2d); cfg.set_stochastic_rounding(stochastic_rounding); - cfg.set_nvfp4_rowwise_1x64_local_encode(use_rowwise_1x64); + cfg.set_nvfp4_rowwise_1x64_local_encode(use_1x64); } /// Preconditions for \p NVFP4Quantizer::quantize_impl (non-split). -inline void require_ok_for_non_split(bool with_rht, bool columnwise, bool sr) { +inline void require_ok_for_non_split(bool with_rht, bool /* columnwise */, bool sr) { if (!local_encode_from_env()) { return; } NVTE_CHECK( !with_rht, "NVTE_NVFP4_ROWWISE_1X64_LOCAL_ENCODE=1 requires non-RHT (e.g. NVTE_NVFP4_DISABLE_RHT=1)."); - NVTE_CHECK( - !columnwise, - "NVTE_NVFP4_ROWWISE_1X64_LOCAL_ENCODE=1 supports rowwise-only NVFP4 output."); NVTE_CHECK(!sr, "NVTE_NVFP4_ROWWISE_1X64_LOCAL_ENCODE=1 is incompatible with stochastic rounding."); } /// Preconditions for \p split_quantize (non-RHT path). -inline void require_ok_for_split(bool want_rowwise, bool have_columnwise, bool sr) { +inline void require_ok_for_split(bool /* want_rowwise */, bool /* have_columnwise */, bool sr) { if (!local_encode_from_env()) { return; } - NVTE_CHECK( - want_rowwise, - "NVTE_NVFP4_ROWWISE_1X64_LOCAL_ENCODE in split_quantize requires rowwise output."); - NVTE_CHECK(!have_columnwise, - "NVTE_NVFP4_ROWWISE_1X64_LOCAL_ENCODE in split_quantize does not support columnwise " - "output."); NVTE_CHECK( !sr, "NVTE_NVFP4_ROWWISE_1X64_LOCAL_ENCODE in split_quantize is incompatible with SR."); } diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 070549442b..611c4cbbc9 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1753,6 +1753,24 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve at::Tensor columnwise_data_tensor, columnwise_scale_inv_tensor, amax_columnwise; const auto bit8_tensor_opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); const auto bit32_tensor_opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + + // In hierarchical 1x64 mode the ``amax`` slot is repurposed from a global + // ``(1,)`` scalar to a per-window FP32 buffer: ``(M, N/64)`` for rowwise + // and ``(N, M/64)`` for the transposed columnwise output. The kernel + // requires ``flat_first_dim % 64 == 0`` and ``flat_last_dim % 64 == 0``, + // which the dispatcher re-checks; we duplicate it here so allocation + // failures surface before we ever launch the kernel. + const bool use_1x64 = nvfp4_1x64::local_encode_from_env(); + if (use_1x64) { + NVTE_CHECK(flat_first_dim % 64 == 0, + "NVFP4 1x64 local encode requires the leading flattened dim to be a multiple of " + "64, got ", + flat_first_dim); + NVTE_CHECK(flat_last_dim % 64 == 0, + "NVFP4 1x64 local encode requires the last dim to be a multiple of 64, got ", + flat_last_dim); + } + if (rowwise_usage) { const std::vector scale_inv_shape_int64(rowwise_scale_inv_shape.begin(), rowwise_scale_inv_shape.end()); @@ -1760,7 +1778,13 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve rowwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts); // hadamard amax kernel will zero out pointer with ZeroAmaxKernel // nvte_compute_amax_with_config will zero out the pointer if needed - amax_rowwise = at::empty({1}, bit32_tensor_opts); + if (use_1x64) { + amax_rowwise = at::empty({static_cast(flat_first_dim), + static_cast(flat_last_dim) / 64}, + bit32_tensor_opts); + } else { + amax_rowwise = at::empty({1}, bit32_tensor_opts); + } } if (columnwise_usage) { const std::vector scale_inv_shape_int64(columnwise_scale_inv_shape.begin(), @@ -1775,7 +1799,13 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve columnwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts); // hadamard amax kernel will zero out pointer with ZeroAmaxKernel // nvte_compute_amax_with_config will zero out the pointer if needed - amax_columnwise = at::empty({1}, bit32_tensor_opts); + if (use_1x64) { + amax_columnwise = at::empty({static_cast(flat_last_dim), + static_cast(flat_first_dim) / 64}, + bit32_tensor_opts); + } else { + amax_columnwise = at::empty({1}, bit32_tensor_opts); + } } // Convert tensors to Python @@ -1848,7 +1878,10 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve out_cpp.set_rowwise_data(rowwise_data_tensor.data_ptr(), DType::kFloat4E2M1, shape); out_cpp.set_rowwise_scale_inv(rowwise_scale_inv_tensor.data_ptr(), DType::kFloat8E4M3, rowwise_scale_inv_shape); - out_cpp.set_amax(amax_rowwise.data_ptr(), DType::kFloat32, std::vector{1}); + const std::vector amax_row_shape = + use_1x64 ? std::vector{flat_first_dim, flat_last_dim / 64} + : std::vector{1}; + out_cpp.set_amax(amax_rowwise.data_ptr(), DType::kFloat32, amax_row_shape); } if (columnwise_usage) { // enforce 2D shape to avoid [S, B, H] shape and B and be 1 @@ -1859,8 +1892,10 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve col_data_shape_fp4); out_cpp.set_columnwise_scale_inv(columnwise_scale_inv_tensor.data_ptr(), DType::kFloat8E4M3, columnwise_scale_inv_shape); - out_cpp.set_columnwise_amax(amax_columnwise.data_ptr(), DType::kFloat32, - std::vector{1}); + const std::vector amax_col_shape = + use_1x64 ? std::vector{flat_last_dim, flat_first_dim / 64} + : std::vector{1}; + out_cpp.set_columnwise_amax(amax_columnwise.data_ptr(), DType::kFloat32, amax_col_shape); } out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); this->set_quantization_params(&out_cpp); @@ -2306,7 +2341,12 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou "Use with_post_rht_amax=true instead."); } } else { // Without RHT - if (compute_amax) { + // The hierarchical 1x64 cast computes per-window amax inside the kernel + // (writing ``(M, N/64)`` / ``(N, M/64)`` buffers into the ``amax`` slots); + // running ``nvte_compute_amax_with_config`` here would overwrite the + // per-window allocation with a single FP32 global amax and clobber the + // shape back to ``(1,)``. Skip the precompute for that path entirely. + if (compute_amax && !nvfp4_1x64::local_encode_from_env()) { // Amax pointers auto rowwise_amax_ptr = out.get_amax().data_ptr; auto columnwise_amax_ptr = out.get_columnwise_amax().data_ptr; diff --git a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_1x64.py b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_1x64.py index e15ad76890..cd1ac40c63 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_1x64.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_1x64.py @@ -2,33 +2,38 @@ # # See LICENSE for license information. -"""Reference implementation for NVFP4 rowwise 1x64 local-encode quantization. +"""Reference implementation for NVFP4 hierarchical 1x64 cast (rowwise + columnwise). The hierarchical 1x64 + 1x16 scheme replaces the per-tensor encoding scaling factor used by stock NVFP4 with a per-1x64-K-window scaling factor; the four 1x16 sub-blocks inside a window share their parent ``S_enc``. The CUDA kernel that implements this lives in -``transformer_engine/common/cast/nvfp4/quantize_nvfp4_1x64_rowwise.cu`` and is +``transformer_engine/common/cast/nvfp4/quantize_nvfp4_1x64.cu`` and is dispatched when ``NVTE_NVFP4_ROWWISE_1X64_LOCAL_ENCODE=1``. This file mirrors that kernel's arithmetic in pure PyTorch so tests can -compare the kernel's output byte-for-byte against a Python oracle (the same -bit-exact methodology used by ``NVFP4QuantizerRef`` for the production NVFP4 -path). The arithmetic ordering and intermediate clamps are chosen to match -what the kernel does: +compare the kernel's output byte-for-byte against a Python oracle. The +arithmetic ordering and intermediate clamps are chosen to match what the +kernel does: * ``S_enc_tile = (FP8_MAX*FP4_MAX) / max(tile_amax, 1e-12)`` clamped to ``fp32_max``; * ``s_dec = saturating_cast(vec_max * S_enc_tile / FP4_MAX)`` (the ``1/FP4_MAX`` is folded into the ``S_enc`` multiplier to match ``compute_decoding_scaling_factor``); -* ``block_scale = S_enc_tile / fp32(s_dec)`` (matches ``__fdiv_rn`` in the - kernel); +* ``block_scale = S_enc_tile / fp32(s_dec)`` with an explicit ``s_dec == 0`` + short-circuit to 0 (matches the kernel's ``s_dec_f == 0.f`` branch); * ``q = round_fp4_satfinite(x_fp32 * block_scale)`` with values clamped to ``[-FP4_MAX, FP4_MAX]`` before packing. -Only the rowwise, non-RHT, non-2D, non-stochastic-rounding path is supported, -matching the kernel's preconditions. +For the columnwise (transposed) output the kernel runs the same math along +the original M direction with a 64x1 window; the reference implements this +by simply running the rowwise routine on ``x.T``. The window amax tensor +is exposed alongside ``data`` / ``scale`` so consumers can reconstruct +``S_enc_window`` (the per-block ``s_dec`` alone is not enough information +to dequantize correctly). + +Only the non-RHT, non-2D, non-stochastic-rounding path is supported. """ from __future__ import annotations @@ -56,59 +61,77 @@ @dataclasses.dataclass class RefNVFP4Tensor1x64: - """Container for the rowwise 1x64 reference output. + """Container for the hierarchical 1x64 reference output. Mirrors the subset of attributes that the bit-exact test inspects. - Naming follows ``quantization_nvfp4.RefNVFP4Tensor`` so the test reads the - same way as ``check_quantization_nvfp4_versus_reference``. Attributes ---------- data: - Packed FP4 bytes, ``(M, N // 2)`` ``uint8``. + Packed rowwise FP4 bytes, ``(M, N // 2)`` ``uint8``. scale: - Per-1x16-block decode scale (E4M3), ``(M, N // 16)`` ``float8_e4m3fn``. - global_amax_row: - Global tensor amax (1-D, single fp32 element). Equals the result of - the kernel's ``atomicMaxFloat`` over all per-tile amax values. + Per-1x16-block rowwise decode scale (E4M3), ``(M, N // 16)`` + ``float8_e4m3fn``. + window_amax_row: + Per-1x64-window rowwise amax, ``(M, N // 64)`` ``float32``. + ``S_enc_window`` is recoverable from this via + ``compute_global_encode_scaling_factor_FP4(window_amax)``; consumers + need it for correct dequantization (the per-block ``s_dec`` alone + does not contain enough information). + columnwise_data, columnwise_scale, window_amax_col: + Their columnwise (transposed) counterparts. The transposed FP4 data + has shape ``(N, M // 2)`` (matching the production cast+transpose + layout), the columnwise scales ``(N, M // 16)``, and the columnwise + window amax ``(N, M // 64)``. ``None`` if columnwise was not + requested. """ data: Optional[torch.Tensor] = None scale: Optional[torch.Tensor] = None - global_amax_row: Optional[torch.Tensor] = None + window_amax_row: Optional[torch.Tensor] = None + columnwise_data: Optional[torch.Tensor] = None + columnwise_scale: Optional[torch.Tensor] = None + window_amax_col: Optional[torch.Tensor] = None class NVFP4Quantizer1x64Ref: - """Reference implementation of the rowwise 1x64 local-encode kernel. + """Reference implementation of the hierarchical 1x64 cast kernel. - The constructor takes no parameters because the kernel itself does not - expose any -- columnwise output, RHT, 2D scaling, and stochastic rounding - are all rejected at dispatch time. Surfacing those as ctor flags here - would only invite the test to drift away from the kernel's actual - capabilities. + Constructor takes flags matching the kernel's two output-direction + switches; RHT, 2D scaling, and stochastic rounding are not exposed + because the kernel rejects them at dispatch time and we do not want + the reference to drift away from the kernel's actual capabilities. """ - def __init__(self) -> None: - # No configurable knobs; see class docstring. - pass + def __init__(self, rowwise: bool = True, columnwise: bool = False) -> None: + if not rowwise and not columnwise: + raise ValueError("At least one of rowwise / columnwise must be True.") + self.rowwise = rowwise + self.columnwise = columnwise @staticmethod - def _quantize_rowwise(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - """Run the 1x64 reference math on a 2D input. - - Returns - ------- - ``(qx, sx, global_amax)`` where shapes match the kernel's compact - rowwise layout: ``qx`` is ``(M, N // 2)`` ``uint8``, ``sx`` is - ``(M, N // BLOCK_K)`` ``float8_e4m3fn``, and ``global_amax`` is - ``(1,)`` ``float32``. + def _quantize_2d( + x: torch.Tensor, + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Run the 1x64 reference math on a 2D input along its trailing dim. + + Returns ``(qx, sx, window_amax)`` where ``qx`` is ``(M, N // 2)`` + ``uint8``, ``sx`` is ``(M, N // BLOCK_K)`` ``float8_e4m3fn``, and + ``window_amax`` is ``(M, N // WINDOW_K)`` ``float32``. + + The columnwise pass is implemented by calling this routine on + ``x.T.contiguous()``; both passes therefore share a single + arithmetic path, which is the cleanest way to keep the two + directions consistent with the kernel (the kernel itself shares + its math by re-using ``compute_decoding_scaling_factor`` / + ``compute_global_encode_scaling_factor_FP4`` between passes). """ if x.ndim != 2: raise ValueError(f"NVFP4Quantizer1x64Ref expects a 2D tensor, got {x.ndim}D") M, N = x.shape - if N % BLOCK_K != 0: + if N % WINDOW_K != 0: raise ValueError( - f"N={N} must be a multiple of BLOCK_K={BLOCK_K} (kernel hard requirement)" + f"N={N} must be a multiple of WINDOW_K={WINDOW_K} (kernel hard requirement)" ) device = x.device @@ -118,31 +141,20 @@ def _quantize_rowwise(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torc fp4_max = torch.tensor(FLOAT4_E2M1_MAX, device=device, dtype=torch.float32) fp8_max = torch.tensor(FLOAT8_E4M3_MAX, device=device, dtype=torch.float32) - # Pad K up to a multiple of WINDOW_K so the reshape into windows is - # well-defined. The kernel itself supports a partial last window via - # the ``win_len`` clamp; we emulate that by zero-padding here and - # trimming the padded columns out of qx/sx at the end (the padded - # blocks are uninitialised in the kernel output, so the test compares - # only the un-padded prefix). - pad_n = (WINDOW_K - N % WINDOW_K) % WINDOW_K - if pad_n > 0: - x_padded = torch.nn.functional.pad(x, (0, pad_n), mode="constant", value=0.0) - else: - x_padded = x.contiguous() - Np = x_padded.shape[1] + Np = N n_win = Np // WINDOW_K n_blk = Np // BLOCK_K - x_padded_fp32 = x_padded.to(torch.float32) - x_win = x_padded_fp32.view(M, n_win, WINDOW_K) - x_blk = x_padded_fp32.view(M, n_blk, BLOCK_K) + x_fp32 = x.to(torch.float32).contiguous() + x_win = x_fp32.view(M, n_win, WINDOW_K) + x_blk = x_fp32.view(M, n_blk, BLOCK_K) - # 1x64 tile amax. The kernel applies fmaxf(tile_amax, 1e-12f) to the - # divisor; do the same here. ``S_enc_tile`` is then computed exactly - # like ``compute_global_encode_scaling_factor_FP4``: divide, clamp to - # fp32_max, and fall back to 1.0 if the divisor or quotient is zero - # (the latter branch is dead given the floor but is mirrored for - # parity). + # 1x64 tile amax. The kernel applies ``fmaxf(tile_amax, 1e-12f)`` to + # the divisor; do the same here. ``S_enc_tile`` is then computed + # exactly like ``compute_global_encode_scaling_factor_FP4``: divide, + # clamp to fp32_max, and fall back to 1.0 if the divisor or quotient + # is zero (the latter branch is dead given the floor but is mirrored + # for parity with the C++ helper). tile_amax = torch.amax(torch.abs(x_win), dim=-1, keepdim=True) # (M, n_win, 1) tile_amax_safe = torch.clamp(tile_amax, min=_TILE_AMAX_FLOOR) @@ -154,16 +166,16 @@ def _quantize_rowwise(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torc S_enc_tile, ) - # Fold (1 / fp4_max) into the multiplier the same way the kernel does - # in ``compute_decoding_scaling_factor`` (``S_enc * fp4_max_inv``). + # Fold ``1 / fp4_max`` into the multiplier the same way the kernel + # does in ``compute_decoding_scaling_factor`` (``S_enc * fp4_max_inv``). # Keeping the operation order identical is what makes the resulting # E4M3 scale bit-exact with the kernel. S_enc_tile_mul_inv6 = S_enc_tile * torch.reciprocal(fp4_max) # 1x16 block amax and per-block S_enc broadcast. Each 1x64 window - # spans exactly BLOCKS_PER_WINDOW (=4) consecutive 1x16 sub-blocks, so - # repeat_interleave along the block axis aligns one S_enc_tile to - # every block inside that window. + # spans BLOCKS_PER_WINDOW (=4) consecutive 1x16 sub-blocks, so + # ``repeat_interleave`` along the block axis aligns one S_enc_tile + # to every block inside that window. vec_max = torch.amax(torch.abs(x_blk), dim=-1, keepdim=True) # (M, n_blk, 1) S_enc_per_blk = S_enc_tile.repeat_interleave(BLOCKS_PER_WINDOW, dim=1) S_enc_per_blk_mul = S_enc_tile_mul_inv6.repeat_interleave(BLOCKS_PER_WINDOW, dim=1) @@ -179,11 +191,11 @@ def _quantize_rowwise(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torc decode_scale_back_fp32 = decode_scale_e4m3.to(torch.float32) # block_scale = S_enc_tile / s_dec, matching ``__fdiv_rn`` in the - # kernel. Padded sub-blocks have vec_max == 0 hence s_dec == 0, which - # would yield +inf here and propagate NaN through the downstream - # multiply. To keep the division warning-free we replace zero - # divisors with 1.0, divide, then mask the result back to zero -- the - # padded slots are trimmed out of the final comparison either way. + # kernel. All-zero blocks have ``s_dec == 0``, which would yield + # ``+inf`` here and propagate NaN through the downstream multiply + # (``cvt.rn.satfinite.e2m1x4.f32(NaN)`` saturates to +FP4_MAX on + # SM10 -- we do NOT want that). Short-circuit to 0 to mirror the + # kernel's ``s_dec_f == 0.f`` branch. zero_blk = decode_scale_back_fp32 == 0 denom = torch.where(zero_blk, torch.ones_like(decode_scale_back_fp32), decode_scale_back_fp32) encode_scale = S_enc_per_blk / denom @@ -193,23 +205,29 @@ def _quantize_rowwise(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torc # Apply scale, clamp to FP4 range, and pack two FP4 values per byte. scaled_x = x_blk * encode_scale clipped_x = torch.clamp(scaled_x, -fp4_max, fp4_max).reshape(M, Np) - qx_packed_padded = cast_to_fp4x2(clipped_x) # (M, Np // 2) - - sx_padded = decode_scale_e4m3.squeeze(-1) # (M, n_blk) + qx = cast_to_fp4x2(clipped_x).contiguous() # (M, N // 2) - # Trim the K-direction padding so the returned tensors describe only - # positions the kernel actually wrote to. - qx = qx_packed_padded[:, : N // 2].contiguous() - sx = sx_padded[:, : N // BLOCK_K].contiguous() + sx = decode_scale_e4m3.squeeze(-1).contiguous() # (M, n_blk) - # ``output.amax`` in the kernel accumulates ``atomicMaxFloat`` over - # every per-tile amax, which is mathematically max-of-maxes -- i.e. - # the global tensor amax. Compute that directly here. - global_amax = torch.amax(torch.abs(x.to(torch.float32))).reshape(1) + # Per-1x64-window amax, exposed for consumer-side dequantization. + window_amax = tile_amax.squeeze(-1).to(torch.float32).contiguous() # (M, n_win) - return qx, sx, global_amax + return qx, sx, window_amax def quantize(self, tensor: torch.Tensor) -> RefNVFP4Tensor1x64: """Quantize ``tensor`` and return a ``RefNVFP4Tensor1x64``.""" - qx, sx, global_amax = self._quantize_rowwise(tensor) - return RefNVFP4Tensor1x64(data=qx, scale=sx, global_amax_row=global_amax) + out = RefNVFP4Tensor1x64() + if self.rowwise: + qx, sx, win_amax = self._quantize_2d(tensor) + out.data = qx + out.scale = sx + out.window_amax_row = win_amax + if self.columnwise: + # The columnwise output is the rowwise quantization of the + # transpose; both directions share the same math and the same + # ``s_dec``/``block_scale`` chain. + qx_t, sx_t, win_amax_t = self._quantize_2d(tensor.transpose(0, 1).contiguous()) + out.columnwise_data = qx_t + out.columnwise_scale = sx_t + out.window_amax_col = win_amax_t + return out From 89aaa0b10262b79b159959b145cdff48709168b3 Mon Sep 17 00:00:00 2001 From: Cael Ling Date: Tue, 28 Apr 2026 19:16:45 -0700 Subject: [PATCH 07/10] Drop colwise group-quantize work from this PR Signed-off-by: Cael Ling --- .../nvfp4/group_quantize_transpose_nvfp4.cuh | 167 ++++++++---------- 1 file changed, 75 insertions(+), 92 deletions(-) diff --git a/transformer_engine/common/cast/nvfp4/group_quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/group_quantize_transpose_nvfp4.cuh index c827f466ed..a2f3dac15a 100644 --- a/transformer_engine/common/cast/nvfp4/group_quantize_transpose_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/group_quantize_transpose_nvfp4.cuh @@ -48,8 +48,6 @@ struct MultiAmaxCastTransposeFusionArgs { void *output_colwise_scale_inv_list[kMaxTensorsPerKernel]; // (Unused for rowwise only scaling) output scale stride for colwise scaling int output_colwise_scale_stride[kMaxTensorsPerKernel]; - // (Unused for rowwise only scaling) output data stride for colwise data (in fp4e2m1x2 units) - int output_colwise_data_stride[kMaxTensorsPerKernel]; // Prefix sum (with leading zero) of split_sections of each tensor of input int split_sections_range[kMaxTensorsPerKernel + 1]; // Number of tensors (splits) being processed by kernel @@ -90,7 +88,7 @@ __device__ __forceinline__ void UpdateEncodeDecodeScaleFP32(float *amax_ptr, flo float *s_dec_ptr) { float s_env_value = (amax_ptr == nullptr) ? 1.0f : compute_global_encode_scaling_factor_FP4(*amax_ptr); - float s_dec_value = 1.0f / s_env_value; + float s_dec_value = 1.0 / s_env_value; *s_enc_ptr = s_env_value; *s_dec_ptr = s_dec_value; return; @@ -204,15 +202,23 @@ __global__ void __launch_bounds__(THREADS_NUM) const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y; const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X; + // TODO(zhongbo): add back when transpose is supported + // const size_t block_offset_Y_t = blockIdx.x * CHUNK_DIM_X; + // const size_t block_offset_X_t = blockIdx.y * CHUNK_DIM_Y; + const size_t chunk_rows = rows - block_offset_Y; const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; const size_t scales_block_offset_X_rowwise = blockIdx.x * SCALES_PER_CHUNK_X; + // TODO(zhongbo): add back when transpose is supported + // const size_t scales_block_offset_Y_t = blockIdx.x * CHUNK_DIM_X; + // const size_t scales_block_offset_X_t = blockIdx.y * SCALES_PER_CHUNK_Y; const size_t tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; const size_t tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; const size_t tid_X_colwise = threadIdx.x; const size_t tid_Y_t = tid_X_colwise; + // const size_t tid_X_t = 0; const size_t thread_offset_Y_rowwise = tid_Y_rowwise; const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM; @@ -226,11 +232,17 @@ __global__ void __launch_bounds__(THREADS_NUM) const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; + // TODO(zhongbo): add back when transpose is supported + // const size_t scales_offset_Y_t = scales_block_offset_Y_t + tid_Y_t; + // const size_t scales_offset_X_t = scales_block_offset_X_t; const size_t SFs_per_row = cols / SCALE_DIM; const bool rowwise_scale_is_within_bounds_X = scales_offset_X_rowwise < SFs_per_row; + // TODO(zhongbo): add back when transpose is supported + // const bool colwise_scale_is_within_bounds_Y = scales_offset_Y_t < cols; + // Helps resolving bank conflicts in shmem const int thread_lane = threadIdx.x % THREADS_PER_WARP; const int bank_group = thread_lane / THREADS_PER_BANK; @@ -271,14 +283,12 @@ __global__ void __launch_bounds__(THREADS_NUM) const bool is_master_thread = (threadIdx.x == 0); + // TODO (zhongbo): finish this float *amax_rowwise_ptr = nullptr; float *amax_colwise_ptr = nullptr; nvfp4_scale_t *split_rowwise_scale_ptr = nullptr; - fp4e2m1x2 *split_colwise_data_ptr = nullptr; - nvfp4_scale_t *split_colwise_scale_ptr = nullptr; - int split_colwise_data_stride = 0; - int split_colwise_scale_stride = 0; + // suppose the amax is fixed for the current 128x128 tile (need 128 padding) bool need_update_tensor_id = true; int tensor_id = GetTensorIdAndBoundary(&kernel_args, block_offset_Y, block_offset_Y + CHUNK_DIM_Y, &need_update_tensor_id); @@ -287,20 +297,12 @@ __global__ void __launch_bounds__(THREADS_NUM) amax_rowwise_ptr = reinterpret_cast(kernel_args.rowwise_amax_list[tensor_id]); split_rowwise_scale_ptr = reinterpret_cast(kernel_args.output_rowwise_scale_inv_list[tensor_id]); - if constexpr (RETURN_TRANSPOSE) { - amax_colwise_ptr = reinterpret_cast(kernel_args.colwise_amax_list[tensor_id]); - split_colwise_data_ptr = - reinterpret_cast(kernel_args.output_colwise_data_list[tensor_id]); - split_colwise_scale_ptr = - reinterpret_cast(kernel_args.output_colwise_scale_inv_list[tensor_id]); - split_colwise_data_stride = kernel_args.output_colwise_data_stride[tensor_id]; - split_colwise_scale_stride = kernel_args.output_colwise_scale_stride[tensor_id]; - } float S_enc_rowwise = 1.0f; float S_dec_rowwise = 1.0f; UpdateEncodeDecodeScaleFP32(amax_rowwise_ptr, &S_enc_rowwise, &S_dec_rowwise); + // TODO (zhongbo): colwise scaling disabled for now because of transpose float S_enc_colwise = 1.0f; float S_dec_colwise = 1.0f; if (amax_colwise_ptr != nullptr) { @@ -343,21 +345,8 @@ __global__ void __launch_bounds__(THREADS_NUM) UpdateEncodeDecodeScaleFP32(amax_rowwise_ptr, &S_enc_rowwise, &S_dec_rowwise); split_rowwise_scale_ptr = reinterpret_cast(kernel_args.output_rowwise_scale_inv_list[tensor_id]); - if constexpr (RETURN_TRANSPOSE) { - amax_colwise_ptr = reinterpret_cast(kernel_args.colwise_amax_list[tensor_id]); - if (amax_colwise_ptr != nullptr) { - UpdateEncodeDecodeScaleFP32(amax_colwise_ptr, &S_enc_colwise, &S_dec_colwise); - } else { - S_enc_colwise = S_enc_rowwise; - S_dec_colwise = S_dec_rowwise; - } - split_colwise_data_ptr = - reinterpret_cast(kernel_args.output_colwise_data_list[tensor_id]); - split_colwise_scale_ptr = reinterpret_cast( - kernel_args.output_colwise_scale_inv_list[tensor_id]); - split_colwise_data_stride = kernel_args.output_colwise_data_stride[tensor_id]; - split_colwise_scale_stride = kernel_args.output_colwise_scale_stride[tensor_id]; - } + // TODO (zhongbo): colwise scaling disabled for now because of transpose + // Skip fetching colwise amax pointer and scaling factor updates } } @@ -450,23 +439,6 @@ __global__ void __launch_bounds__(THREADS_NUM) tid_Y_t * SCALES_PER_CHUNK_Y + stage * ITERATIONS_TRANSPOSE + it; out_colwise_scales_sh[scale_idx_sh] = S_dec_b_fp8; - // Write colwise scale directly to per-split global buffer (streaming store) - if (split_colwise_scale_ptr != nullptr) { - const size_t global_row = block_offset_Y + stage_offset_Y + it * SCALE_DIM; - const size_t col_idx = block_offset_X + threadIdx.x; - const bool within_split = (global_row >= split_start) && (global_row < split_end); - if (!col_out_of_bounds_colwise && within_split) { - const size_t local_block = (global_row - split_start) / SCALE_DIM; - nvfp4_scale_t *scale_dst = - split_colwise_scale_ptr + col_idx * split_colwise_scale_stride + local_block; - asm volatile( - "st.global.cs.u8 [%0], %1;\n" - : - : "l"(scale_dst), "r"(static_cast(S_dec_b_fp8)) - : "memory"); - } - } - // Compute "correct" per-block encoding scaling factor constexpr float float_max = detail::TypeExtrema::max; const float block_scale_inverse = fminf( @@ -512,34 +484,6 @@ __global__ void __launch_bounds__(THREADS_NUM) out_t_data_sh_as_uint32_t[group] = val[0]; // idx1 = (group + 0) % 2; out_t_data_sh_as_uint32_t[(group + 1) & 1] = val[1]; // idx2 = (group + 1) % 2; } - - // Write colwise quantized data from shmem to per-split global buffers. - // Issued here (right after colwise quantize) so that the strided global - // stores overlap with the rowwise compute that follows. Each thread only - // reads its own 16-byte shmem region, so no __syncthreads is needed. - if (split_colwise_data_ptr != nullptr) { - const size_t global_stage_row = block_offset_Y + stage_offset_Y; - const size_t dst_row_t = block_offset_X + threadIdx.x; - const bool within_bounds = (dst_row_t < cols) && (global_stage_row < rows) && - (global_stage_row >= split_start) && - (global_stage_row < split_end); - if (within_bounds) { - const size_t local_stage_row = global_stage_row - split_start; - const size_t dst_byte_col = local_stage_row / 2; - fp4e2m1x2 *dst = - split_colwise_data_ptr + dst_row_t * split_colwise_data_stride + dst_byte_col; - const fp4e2m1x2 *src = - &out_t_data_sh[buff_offset_out_t + threadIdx.x * BUFF_OUT_T_DIM_X]; - // Use cache-streaming store to avoid L2 pollution from strided writes - const uint4 src_val = *reinterpret_cast(src); - asm volatile( - "st.global.cs.v4.u32 [%0], {%1, %2, %3, %4};\n" - : - : "l"(reinterpret_cast(dst)), - "r"(src_val.x), "r"(src_val.y), "r"(src_val.z), "r"(src_val.w) - : "memory"); - } - } } // ROWWISE scaling @@ -671,6 +615,12 @@ __global__ void __launch_bounds__(THREADS_NUM) const bool rowwise_scale_is_within_bounds_Y = (stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE) < chunk_rows; + // TODO(zhongbo): depending on input padding multiple (whether 128 or 64), use either scale_ptr or split_rowwise_scale_ptr + // const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X; + // if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) { + // scales_ptr[scale_idx_global] = S_dec_b_fp8; + // } + // Map to local split coordinates const size_t split_rows = split_end - split_start; const size_t local_scale_row = scales_offset_Y - split_start; @@ -738,16 +688,46 @@ __global__ void __launch_bounds__(THREADS_NUM) const size_t global_offset_Y = block_offset_Y + stage_offset_Y; const size_t global_offset_X = block_offset_X; + // TODO(zhongbo): add back when transpose is supported + // const size_t global_offset_Y_t = block_offset_Y_t; + // const size_t global_offset_X_t = block_offset_X_t + stage_offset_Y; + ptx::cp_async_bulk_tensor_2d_shared_to_global( reinterpret_cast(&tensor_map_output), global_offset_X, global_offset_Y, reinterpret_cast(&out_data_sh[buff_offset_out])); + // TODO(zhongbo): add back when transpose is supported + // if constexpr (RETURN_TRANSPOSE) { + // ptx::cp_async_bulk_tensor_2d_shared_to_global( + // reinterpret_cast(&tensor_map_output_t), global_offset_X_t, + // global_offset_Y_t, reinterpret_cast(&out_t_data_sh[buff_offset_out_t])); + // } + // Create a "bulk async-group" out of the previous bulk copy operation. ptx::cp_async_bulk_commit_group(); } - } // end of stages + // TODO(zhongbo): add back when transpose is supported + // Vectorized store scaling factors through SHMEM + // if (RETURN_TRANSPOSE && colwise_scale_is_within_bounds_Y) { + // using ScalesVec = Vec; + // const size_t scale_idx_sh = tid_Y_t * SCALES_PER_CHUNK_Y; + // ScalesVec &scales_vec = *reinterpret_cast(&out_colwise_scales_sh[scale_idx_sh]); + // const size_t scale_idx_global = scales_offset_Y_t * scale_stride_t + scales_offset_X_t; + // const size_t count = // number of scales in Y dimension of this chunk + // (chunk_rows >= CHUNK_DIM_Y) ? SCALES_PER_CHUNK_Y : (chunk_rows / SCALE_DIM); + // nvfp4_scale_t *dst = &scales_t_ptr[scale_idx_global]; + // constexpr size_t vec_bytes = SCALES_PER_CHUNK_Y * sizeof(nvfp4_scale_t); + // if (count == SCALES_PER_CHUNK_Y && (reinterpret_cast(dst) % vec_bytes == 0)) { + // // Fast path: vectorized store when destination is properly aligned + // scales_vec.store_to(dst); + // } else { + // // Safe path: element-wise store for tails or unaligned destinations + // scales_vec.store_to_elts(dst, 0, count); + // } + // } + destroy_barriers(mbar, is_master_thread); #else NVTE_DEVICE_ERROR("sm_100 or higher is required."); @@ -773,16 +753,22 @@ void group_quantize_transpose(const Tensor &input, const Tensor *noop, "Number of tensors should be less than or equal to ", kMaxTensorsPerKernel); Tensor *output = nullptr; - // loop over the list to find the first non-empty tensor with actual data + // loop over the list to find the first non-empty tensor for (size_t i = 0; i < num_tensors; ++i) { - if (output_list[i]->has_data() && output_list[i]->data.dptr != nullptr) { + if (output_list[i]->has_data()) { output = output_list[i]; break; } } - NVTE_CHECK(output != nullptr, "No output tensor found with non-null data pointer."); + NVTE_CHECK(output != nullptr, "No output tensor found."); + // also check that the output has not null data pointer + NVTE_CHECK(output->data.dptr != nullptr, "Output data pointer is null."); + // If transposed output is allocated, return the transposed data. Otherwise, it's not necesary to + // return the transposed data. bool return_transpose = output->has_columnwise_data(); + // forbid return transpose for now because group quantize transpose is not supported yet + NVTE_CHECK(!return_transpose, "Return transpose is not supported for group quantize transpose."); // output_List is contiguous in memory, so take the first tensor as the contiguous output auto output_contiguous = output->data; @@ -817,20 +803,10 @@ void group_quantize_transpose(const Tensor &input, const Tensor *noop, reinterpret_cast(output_list[i]->amax.dptr); kernel_args.output_rowwise_scale_inv_list[kernel_args.num_tensors] = reinterpret_cast(output_list[i]->scale_inv.dptr); - if (return_transpose) { - kernel_args.colwise_amax_list[kernel_args.num_tensors] = - reinterpret_cast(output_list[i]->columnwise_amax.dptr); - kernel_args.output_colwise_data_list[kernel_args.num_tensors] = - reinterpret_cast(output_list[i]->columnwise_data.dptr); - kernel_args.output_colwise_scale_inv_list[kernel_args.num_tensors] = - reinterpret_cast(output_list[i]->columnwise_scale_inv.dptr); - kernel_args.output_colwise_data_stride[kernel_args.num_tensors] = - static_cast(output_list[i]->columnwise_data.shape[1] / 2); - kernel_args.output_colwise_scale_stride[kernel_args.num_tensors] = - static_cast(output_list[i]->columnwise_scale_inv.shape[1]); - } + // kernel_args.split_sections[kernel_args.num_tensors] = split_sections[i]; kernel_args.split_sections_range[kernel_args.num_tensors + 1] = kernel_args.split_sections_range[kernel_args.num_tensors] + split_sections[i]; + // check overflow NVTE_CHECK(kernel_args.split_sections_range[kernel_args.num_tensors + 1] >= 0, "split_sections_range overflow the int32_t"); kernel_args.num_tensors++; @@ -846,6 +822,8 @@ void group_quantize_transpose(const Tensor &input, const Tensor *noop, // for the colwise scaling, scaling factor stride is different for each tensor because of transpose // since transpose puts token dimension splits in the last dimension of the tensor const size_t scale_stride = output->scale_inv.shape[1]; + // const size_t scale_stride_transpose = + // return_transpose ? output->columnwise_scale_inv.shape[1] : 0; nvfp4_scale_t *const scales_ptr = reinterpret_cast(output->scale_inv.dptr); @@ -866,12 +844,17 @@ void group_quantize_transpose(const Tensor &input, const Tensor *noop, alignas(64) CUtensorMap tensor_map_input{}; alignas(64) CUtensorMap tensor_map_output{}; + // alignas(64) CUtensorMap tensor_map_output_transpose{}; create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0, sizeof(IType) * 8); create_2D_tensor_map(tensor_map_output, output_contiguous, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, cols, 0, 4); + // if (return_transpose) { + // create_2D_tensor_map(tensor_map_output_transpose, output->columnwise_data, cols, rows, + // BUFF_DIM_X, BUFF_DIM_Y, rows, 0, 4); + // } constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; constexpr size_t buff_size_aligned_in = From 7c35e320d6471053802155acbd5b7f4c668b1a60 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 29 Apr 2026 03:56:25 +0000 Subject: [PATCH 08/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../compare_64x64_global_vs_1x64.py | 5 +- .../fp8_e4m3_utils_np.py | 4 +- .../hierarchical_nvfp4_ref.py | 32 +++------- .../hierarchical_nvfp4_ref_numpy.py | 34 ++++++---- .../pytorch/nvfp4/test_nvfp4_1x64_accuracy.py | 26 ++++---- .../nvfp4/test_nvfp4_1x64_quantize_exact.py | 20 ++---- .../common/cast/dispatch/quantize.cuh | 5 +- .../common/cast/nvfp4/quantize_nvfp4_1x64.cu | 62 +++++++++---------- .../pytorch/csrc/extensions/cast.cpp | 2 +- transformer_engine/pytorch/csrc/nvfp4_1x64.h | 12 ++-- transformer_engine/pytorch/csrc/quantizer.cpp | 21 +++---- .../custom_recipes/quantization_nvfp4_1x64.py | 8 +-- 12 files changed, 108 insertions(+), 123 deletions(-) diff --git a/reference_hierarchical_nvfp4/compare_64x64_global_vs_1x64.py b/reference_hierarchical_nvfp4/compare_64x64_global_vs_1x64.py index 78c6054504..cfdad13cfa 100644 --- a/reference_hierarchical_nvfp4/compare_64x64_global_vs_1x64.py +++ b/reference_hierarchical_nvfp4/compare_64x64_global_vs_1x64.py @@ -106,7 +106,10 @@ def main() -> None: fn = float(np.linalg.norm(x, "fro")) if fn > 0: print("||x||_F =", fn) - print("Fro 相对: ||recon_1x64 - recon_global||_F / ||x||_F =", float(np.linalg.norm(dg, "fro") / fn)) + print( + "Fro 相对: ||recon_1x64 - recon_global||_F / ||x||_F =", + float(np.linalg.norm(dg, "fro") / fn), + ) if __name__ == "__main__": diff --git a/reference_hierarchical_nvfp4/fp8_e4m3_utils_np.py b/reference_hierarchical_nvfp4/fp8_e4m3_utils_np.py index 46b8f87cfa..cecd42a537 100644 --- a/reference_hierarchical_nvfp4/fp8_e4m3_utils_np.py +++ b/reference_hierarchical_nvfp4/fp8_e4m3_utils_np.py @@ -28,9 +28,7 @@ def _decode_e4m3_byte(b: int) -> float: # Precompute 256 float32 values for all E4M3 codes -_E4M3_TABLE: np.ndarray = np.array( - [_decode_e4m3_byte(i) for i in range(256)], dtype=np.float32 -) +_E4M3_TABLE: np.ndarray = np.array([_decode_e4m3_byte(i) for i in range(256)], dtype=np.float32) def f32_to_e4m3_u8(x: np.ndarray) -> np.ndarray: diff --git a/reference_hierarchical_nvfp4/hierarchical_nvfp4_ref.py b/reference_hierarchical_nvfp4/hierarchical_nvfp4_ref.py index 3b92918752..8df743d781 100644 --- a/reference_hierarchical_nvfp4/hierarchical_nvfp4_ref.py +++ b/reference_hierarchical_nvfp4/hierarchical_nvfp4_ref.py @@ -177,11 +177,7 @@ def quantize_rowwise_1x64_1x16(x: torch.Tensor, eps: float = 1e-12) -> Hierarchi if bamax < eps: bamax = float(eps) raw = compute_S_dec_f32_before_cast_te(bamax, float(s_e.item())) - u = int( - _f32_e4m3_u8_np(np.array([raw], dtype=np.float32).reshape(1))[ - 0 - ] - ) + u = int(_f32_e4m3_u8_np(np.array([raw], dtype=np.float32).reshape(1))[0]) S_dec_u8[row, t16b] = u s_dec_f = max(float(_e4m3_u8_f32_np(np.array([u], dtype=np.uint8))[0]), TINY) bsi = float(s_e.item()) / s_dec_f @@ -196,9 +192,9 @@ def dequantize_rowwise(p: HierarchicalNVFP4Rowwise) -> torch.Tensor: q = _unpack_fp4_along_k(p.data_u8, k) j16 = (torch.arange(k, device=device) // FINE).long() j64 = (torch.arange(k, device=device) // COARSE).long() - sdec = torch.from_numpy( - _e4m3_u8_f32_np(p.S_dec_u8[:, j16].cpu().numpy().astype(np.uint8)) - ).to(device=device, dtype=torch.float32) + sdec = torch.from_numpy(_e4m3_u8_f32_np(p.S_dec_u8[:, j16].cpu().numpy().astype(np.uint8))).to( + device=device, dtype=torch.float32 + ) senc = p.S_enc[:, j64] sdec = torch.clamp(sdec, min=TINY) return (q * (sdec / senc)).to(torch.float32) @@ -214,9 +210,7 @@ def _amax_64_m(x: torch.Tensor) -> torch.Tensor: return a -def quantize_columnwise_1x64_1x16( - x: torch.Tensor, eps: float = 1e-12 -) -> HierarchicalNVFP4Colwise: +def quantize_columnwise_1x64_1x16(x: torch.Tensor, eps: float = 1e-12) -> HierarchicalNVFP4Colwise: assert x.dim() == 2 m, k = int(x.size(0)), int(x.size(1)) device = x.device @@ -242,11 +236,7 @@ def quantize_columnwise_1x64_1x16( if bamax < eps: bamax = float(eps) raw = compute_S_dec_f32_before_cast_te(bamax, float(s_e.item())) - u = int( - _f32_e4m3_u8_np(np.array([raw], dtype=np.float32).reshape(1))[ - 0 - ] - ) + u = int(_f32_e4m3_u8_np(np.array([raw], dtype=np.float32).reshape(1))[0]) S_dec_u8[t16b, col] = u s_dec_f = max(float(_e4m3_u8_f32_np(np.array([u], dtype=np.uint8))[0]), TINY) bsi = float(s_e.item()) / s_dec_f @@ -261,9 +251,9 @@ def dequantize_colwise(p: HierarchicalNVFP4Colwise) -> torch.Tensor: q = _unpack_fp4_along_m(p.data_u8, m, k) r16 = (torch.arange(m, device=device) // FINE).long() r64 = (torch.arange(m, device=device) // COARSE).long() - sdec = torch.from_numpy( - _e4m3_u8_f32_np(p.S_dec_u8[r16, :].cpu().numpy().astype(np.uint8)) - ).to(device=device, dtype=torch.float32) + sdec = torch.from_numpy(_e4m3_u8_f32_np(p.S_dec_u8[r16, :].cpu().numpy().astype(np.uint8))).to( + device=device, dtype=torch.float32 + ) senc = p.S_enc[r64, :] sdec = torch.clamp(sdec, min=TINY) return (q * (sdec / senc)).to(torch.float32) @@ -276,9 +266,7 @@ def reference_matmul_tn( return dequantize_rowwise(a_rows) @ dequantize_colwise(b_cols).T -def roundtrip_error( - x: torch.Tensor, mode: str -) -> Tuple[torch.Tensor, torch.Tensor]: +def roundtrip_error(x: torch.Tensor, mode: str) -> Tuple[torch.Tensor, torch.Tensor]: if mode == "rowwise": p = quantize_rowwise_1x64_1x16(x) y = dequantize_rowwise(p) diff --git a/reference_hierarchical_nvfp4/hierarchical_nvfp4_ref_numpy.py b/reference_hierarchical_nvfp4/hierarchical_nvfp4_ref_numpy.py index 9d3dc43830..970ac41917 100644 --- a/reference_hierarchical_nvfp4/hierarchical_nvfp4_ref_numpy.py +++ b/reference_hierarchical_nvfp4/hierarchical_nvfp4_ref_numpy.py @@ -39,8 +39,22 @@ FP4_E2M1_GRID = np.array( [ - 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, - -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0, + 0.0, + 0.5, + 1.0, + 1.5, + 2.0, + 3.0, + 4.0, + 6.0, + -0.0, + -0.5, + -1.0, + -1.5, + -2.0, + -3.0, + -4.0, + -6.0, ], dtype=np.float32, ) @@ -170,9 +184,7 @@ def quantize_rowwise_1x64_1x16(x: np.ndarray, eps: float = 1e-12) -> Hierarchica w[row, lo:hi] = segx * bsi t16b += 1 q = _round_to_nearest_fp4(w) - return HierarchicalNVFP4RowwiseNp( - m, k, _pack_fp4_along_k(q), S_enc, S_dec_u8, amax_64 - ) + return HierarchicalNVFP4RowwiseNp(m, k, _pack_fp4_along_k(q), S_enc, S_dec_u8, amax_64) def dequantize_rowwise(p: HierarchicalNVFP4RowwiseNp) -> np.ndarray: @@ -197,9 +209,7 @@ def _amax_64_m(x: np.ndarray, m: int, k: int) -> np.ndarray: return out -def quantize_columnwise_1x64_1x16( - x: np.ndarray, eps: float = 1e-12 -) -> HierarchicalNVFP4ColwiseNp: +def quantize_columnwise_1x64_1x16(x: np.ndarray, eps: float = 1e-12) -> HierarchicalNVFP4ColwiseNp: x = np.asarray(x, dtype=np.float32) assert x.ndim == 2 m, k = int(x.shape[0]), int(x.shape[1]) @@ -225,14 +235,14 @@ def quantize_columnwise_1x64_1x16( raw = compute_S_dec_f32_before_cast_te(bamax, S) u = f32_to_e4m3_u8(np.array(raw, dtype=np.float32).reshape(1)) S_dec_u8[t16b, col] = u.ravel()[0] - s_dec_f = max(float(e4m3_u8_to_f32(S_dec_u8[t16b : t16b + 1, col : col + 1])[0, 0]), TINY) + s_dec_f = max( + float(e4m3_u8_to_f32(S_dec_u8[t16b : t16b + 1, col : col + 1])[0, 0]), TINY + ) bsi = S / s_dec_f w[lo:hi, col] = segx * bsi t16b += 1 q = _round_to_nearest_fp4(w) - return HierarchicalNVFP4ColwiseNp( - m, k, _pack_fp4_along_m(q), S_enc, S_dec_u8, amax_64 - ) + return HierarchicalNVFP4ColwiseNp(m, k, _pack_fp4_along_m(q), S_enc, S_dec_u8, amax_64) def dequantize_colwise(p: HierarchicalNVFP4ColwiseNp) -> np.ndarray: diff --git a/tests/pytorch/nvfp4/test_nvfp4_1x64_accuracy.py b/tests/pytorch/nvfp4/test_nvfp4_1x64_accuracy.py index 79251ba089..84361a1456 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_1x64_accuracy.py +++ b/tests/pytorch/nvfp4/test_nvfp4_1x64_accuracy.py @@ -152,9 +152,7 @@ def _err_metrics(x: torch.Tensor, recon: torch.Tensor) -> Tuple[float, float, fl return rmse, max_err, frob_rel -def _gen_gaussian( - M: int, N: int, *, seed: int, device: str, dtype: torch.dtype -) -> torch.Tensor: +def _gen_gaussian(M: int, N: int, *, seed: int, device: str, dtype: torch.dtype) -> torch.Tensor: """Uniform N(0, 1) -- a benign baseline where both schemes should tie.""" g = torch.Generator(device=device).manual_seed(seed) return torch.randn((M, N), generator=g, device=device, dtype=dtype) @@ -189,9 +187,7 @@ def _gen_per_window_dynamic_range( log_scales = torch.empty((M, n_win, 1), device=device, dtype=torch.float32) log_scales.uniform_(log10_lo, log10_hi, generator=g) scales = torch.pow(torch.tensor(10.0, device=device, dtype=torch.float32), log_scales) - base = torch.randn( - (M, n_win, WINDOW_K), generator=g, device=device, dtype=torch.float32 - ) + base = torch.randn((M, n_win, WINDOW_K), generator=g, device=device, dtype=torch.float32) x = (base * scales).reshape(M, n_win * WINDOW_K)[:, :N].contiguous() return x.to(dtype) @@ -243,9 +239,13 @@ def _gen_sparse_extreme_outliers( flat_idx = torch.randperm(total_windows, generator=g, device=device)[:n_outlier_windows] outlier_rows = flat_idx // n_win outlier_cols = (flat_idx % n_win) * WINDOW_K - signs = torch.randint( - 0, 2, (n_outlier_windows,), generator=g, device=device, dtype=torch.int32 - ).to(torch.float32) * 2 - 1 + signs = ( + torch.randint(0, 2, (n_outlier_windows,), generator=g, device=device, dtype=torch.int32).to( + torch.float32 + ) + * 2 + - 1 + ) x[outlier_rows, outlier_cols] = signs * float(outlier_mag) return x.to(dtype) @@ -288,7 +288,7 @@ def test_1x64_at_least_as_good_as_per_tensor_on_gaussian( ) assert rmse_1x64 <= rmse_pt * 1.05, ( - f"1x64 RMSE unexpectedly worse than per-tensor on uniform input: " + "1x64 RMSE unexpectedly worse than per-tensor on uniform input: " f"rmse_1x64={rmse_1x64:.4e} > 1.05 * rmse_pt={rmse_pt:.4e}" ) @@ -341,8 +341,8 @@ def test_1x64_better_than_per_tensor_on_sparse_extreme_outliers( ) assert rmse_1x64 < rmse_pt, ( - f"1x64 was not strictly better than per-tensor on " - f"sparse-extreme-outlier input: " + "1x64 was not strictly better than per-tensor on " + "sparse-extreme-outlier input: " f"rmse_1x64={rmse_1x64:.4e} >= rmse_pt={rmse_pt:.4e}" ) @@ -382,6 +382,6 @@ def test_1x64_at_least_tied_on_modest_per_window_dynamic_range( ) assert rmse_1x64 <= rmse_pt * 1.05, ( - f"1x64 unexpectedly worse than per-tensor on modest dynamic-range " + "1x64 unexpectedly worse than per-tensor on modest dynamic-range " f"input: rmse_1x64={rmse_1x64:.4e} > 1.05 * rmse_pt={rmse_pt:.4e}" ) diff --git a/tests/pytorch/nvfp4/test_nvfp4_1x64_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_1x64_quantize_exact.py index 324f9abd4e..7a56b0f8c2 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_1x64_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_1x64_quantize_exact.py @@ -94,9 +94,7 @@ def _check_quantization_1x64_versus_reference_with_input( torch.testing.assert_close( sxt_sut[: sxt_ref.shape[0], : sxt_ref.shape[1]], sxt_ref, atol=0.0, rtol=0.0 ) - torch.testing.assert_close( - sut._amax_columnwise, ref.window_amax_col, atol=0.0, rtol=0.0 - ) + torch.testing.assert_close(sut._amax_columnwise, ref.window_amax_col, atol=0.0, rtol=0.0) def _check_random( @@ -113,9 +111,7 @@ def _check_random( torch.manual_seed(seed) torch.cuda.manual_seed(seed) x = torch.randn((M, N), dtype=x_dtype, device=device) - _check_quantization_1x64_versus_reference_with_input( - x, rowwise=rowwise, columnwise=columnwise - ) + _check_quantization_1x64_versus_reference_with_input(x, rowwise=rowwise, columnwise=columnwise) # Shapes where both M and N are multiples of 64 -- the 1x64 hierarchy's @@ -141,9 +137,7 @@ def _check_random( @pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) @pytest.mark.parametrize("M, N", _SHAPES_64x64_MULTIPLE) @pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) -def test_nvfp4_1x64_quantize_rowwise( - monkeypatch, x_dtype: torch.dtype, M: int, N: int -) -> None: +def test_nvfp4_1x64_quantize_rowwise(monkeypatch, x_dtype: torch.dtype, M: int, N: int) -> None: """Rowwise-only configuration -- preserves the original PR's coverage.""" monkeypatch.setenv("NVTE_NVFP4_ROWWISE_1X64_LOCAL_ENCODE", "1") monkeypatch.setenv("NVTE_NVFP4_DISABLE_RHT", "1") @@ -153,9 +147,7 @@ def test_nvfp4_1x64_quantize_rowwise( @pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) @pytest.mark.parametrize("M, N", _SHAPES_64x64_MULTIPLE) @pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) -def test_nvfp4_1x64_quantize_columnwise( - monkeypatch, x_dtype: torch.dtype, M: int, N: int -) -> None: +def test_nvfp4_1x64_quantize_columnwise(monkeypatch, x_dtype: torch.dtype, M: int, N: int) -> None: """Columnwise-only -- exercises the transposed output path on its own.""" monkeypatch.setenv("NVTE_NVFP4_ROWWISE_1X64_LOCAL_ENCODE", "1") monkeypatch.setenv("NVTE_NVFP4_DISABLE_RHT", "1") @@ -230,6 +222,4 @@ def test_nvfp4_1x64_quantize_extrema( # path simultaneously for that row's columns it spans. x[2] = 0.75 - _check_quantization_1x64_versus_reference_with_input( - x, rowwise=rowwise, columnwise=columnwise - ) + _check_quantization_1x64_versus_reference_with_input(x, rowwise=rowwise, columnwise=columnwise) diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index 969e0b09ef..f4e16077a0 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -248,8 +248,9 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens CheckInputTensor(*grad_tensor, "input"); CheckOutputTensor(*output_tensor, "output", false); - NVTE_CHECK(!quant_config_cpp.nvfp4_rowwise_1x64_local_encode, - "NVFP4 rowwise 1x64 local encode is not implemented for backward quantization yet."); + NVTE_CHECK( + !quant_config_cpp.nvfp4_rowwise_1x64_local_encode, + "NVFP4 rowwise 1x64 local encode is not implemented for backward quantization yet."); // Choose kernel int32_t rows = grad_tensor->flat_first_dim(); diff --git a/transformer_engine/common/cast/nvfp4/quantize_nvfp4_1x64.cu b/transformer_engine/common/cast/nvfp4/quantize_nvfp4_1x64.cu index becb07a2ea..d31bab6ee3 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_nvfp4_1x64.cu +++ b/transformer_engine/common/cast/nvfp4/quantize_nvfp4_1x64.cu @@ -4,9 +4,9 @@ * See LICENSE for license information. ************************************************************************/ -#include "common/common.h" -#include "common/cast/nvfp4/quantize_nvfp4_1x64.cuh" #include "common/cast/nvfp4/core_nvfp4.cuh" +#include "common/cast/nvfp4/quantize_nvfp4_1x64.cuh" +#include "common/common.h" #include "common/util/ptx.cuh" #include "common/utils.cuh" @@ -17,9 +17,9 @@ namespace { #if FP4_TYPE_SUPPORTED +using core::compute_global_encode_scaling_factor_FP4; using ptx::FPx2; using quantization_SF::compute_decoding_scaling_factor; -using core::compute_global_encode_scaling_factor_FP4; // One CUDA block = one 64x64 input tile in (M, N) row-major space. // @@ -36,17 +36,17 @@ using core::compute_global_encode_scaling_factor_FP4; // (``in_sm[e][tid]`` walking down a column) does not fall on the same // 32-bank lane for every row. template -__global__ void __launch_bounds__(64) nvfp4_1x64_fused_per_tile( - const IType* __restrict__ in, const size_t rows, const size_t cols, const int ld_row_elts, - // Rowwise outputs (all three are non-null together, or all null). - uint8_t* __restrict__ q_row, fp8e4m3* __restrict__ s_dec_row, - float* __restrict__ w_amax_row, const size_t s_dec_row_stride, - const size_t w_amax_row_stride, - // Columnwise (transposed) outputs (all three together, or all null). - uint8_t* __restrict__ q_col, fp8e4m3* __restrict__ s_dec_col, - float* __restrict__ w_amax_col, const size_t s_dec_col_stride, - const size_t w_amax_col_stride, - const float* __restrict__ noop) { +__global__ void __launch_bounds__(64) + nvfp4_1x64_fused_per_tile(const IType* __restrict__ in, const size_t rows, const size_t cols, + const int ld_row_elts, + // Rowwise outputs (all three are non-null together, or all null). + uint8_t* __restrict__ q_row, fp8e4m3* __restrict__ s_dec_row, + float* __restrict__ w_amax_row, const size_t s_dec_row_stride, + const size_t w_amax_row_stride, + // Columnwise (transposed) outputs (all three together, or all null). + uint8_t* __restrict__ q_col, fp8e4m3* __restrict__ s_dec_col, + float* __restrict__ w_amax_col, const size_t s_dec_col_stride, + const size_t w_amax_col_stride, const float* __restrict__ noop) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) if (noop != nullptr && noop[0] == 1.0f) { return; @@ -79,8 +79,7 @@ __global__ void __launch_bounds__(64) nvfp4_1x64_fused_per_tile( #pragma unroll for (int e = 0; e < 64; e++) { const int gc = col_base + e; - in_sm[tid][e] = - static_cast(in[static_cast(gr) * ld_row_elts + gc]); + in_sm[tid][e] = static_cast(in[static_cast(gr) * ld_row_elts + gc]); } } else { #pragma unroll @@ -235,8 +234,7 @@ void quantize_1x64_local_encode(const Tensor& input, const Tensor& noop, Tensor* s_dec_row_stride = output->scale_inv.shape.size() > 1 ? output->scale_inv.shape[1] : 1; w_amax_row_stride = output->amax.shape.size() > 1 ? output->amax.shape[1] : 1; NVTE_CHECK(s_dec_row_stride == cols / 16, - "NVFP4 1x64: rowwise scale_inv stride must equal cols/16, got ", - s_dec_row_stride); + "NVFP4 1x64: rowwise scale_inv stride must equal cols/16, got ", s_dec_row_stride); NVTE_CHECK(w_amax_row_stride == cols / 64, "NVFP4 1x64: rowwise amax stride must equal cols/64, got ", w_amax_row_stride); } @@ -254,12 +252,10 @@ void quantize_1x64_local_encode(const Tensor& input, const Tensor& noop, Tensor* q_col = reinterpret_cast(output->columnwise_data.dptr); s_dec_col = reinterpret_cast(output->columnwise_scale_inv.dptr); w_amax_col = reinterpret_cast(output->columnwise_amax.dptr); - s_dec_col_stride = output->columnwise_scale_inv.shape.size() > 1 - ? output->columnwise_scale_inv.shape[1] - : 1; - w_amax_col_stride = output->columnwise_amax.shape.size() > 1 - ? output->columnwise_amax.shape[1] - : 1; + s_dec_col_stride = + output->columnwise_scale_inv.shape.size() > 1 ? output->columnwise_scale_inv.shape[1] : 1; + w_amax_col_stride = + output->columnwise_amax.shape.size() > 1 ? output->columnwise_amax.shape[1] : 1; NVTE_CHECK(s_dec_col_stride == rows / 16, "NVFP4 1x64: columnwise scale_inv stride must equal rows/16, got ", s_dec_col_stride); @@ -272,16 +268,14 @@ void quantize_1x64_local_encode(const Tensor& input, const Tensor& noop, Tensor* dim3 grid(static_cast(n_win), static_cast(m_tiles), 1); constexpr int kBlock = 64; - TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( - input.dtype(), IType, { - const IType* in_t = reinterpret_cast(input.data.dptr); - nvfp4_1x64_fused_per_tile<<>>( - in_t, rows, cols, static_cast(cols), q_row, s_dec_row, w_amax_row, - s_dec_row_stride, w_amax_row_stride, q_col, s_dec_col, w_amax_col, - s_dec_col_stride, w_amax_col_stride, - reinterpret_cast(noop.data.dptr)); - NVTE_CHECK_CUDA(cudaGetLastError()); - }); + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(input.dtype(), IType, { + const IType* in_t = reinterpret_cast(input.data.dptr); + nvfp4_1x64_fused_per_tile<<>>( + in_t, rows, cols, static_cast(cols), q_row, s_dec_row, w_amax_row, s_dec_row_stride, + w_amax_row_stride, q_col, s_dec_col, w_amax_col, s_dec_col_stride, w_amax_col_stride, + reinterpret_cast(noop.data.dptr)); + NVTE_CHECK_CUDA(cudaGetLastError()); + }); #else (void)input; diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 2ce423b358..ee0cdb61dc 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -1263,7 +1263,7 @@ void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, quantizer.stochastic_rounding, use_rowwise_1x64); } nvfp4_1x64::require_ok_for_split(quantizer.rowwise_usage, quantizer.columnwise_usage, - quantizer.stochastic_rounding); + quantizer.stochastic_rounding); if (!use_rowwise_1x64) { split_nvfp4_non_rht_run_grouped_amax( diff --git a/transformer_engine/pytorch/csrc/nvfp4_1x64.h b/transformer_engine/pytorch/csrc/nvfp4_1x64.h index 3e9372b812..e61e4db784 100644 --- a/transformer_engine/pytorch/csrc/nvfp4_1x64.h +++ b/transformer_engine/pytorch/csrc/nvfp4_1x64.h @@ -14,9 +14,10 @@ #ifndef TRANSFORMER_ENGINE_PYTORCH_NVFP4_1X64_H_ #define TRANSFORMER_ENGINE_PYTORCH_NVFP4_1X64_H_ +#include + #include "common/util/logging.h" #include "common/util/system.h" -#include namespace transformer_engine::pytorch::nvfp4_1x64 { @@ -29,7 +30,7 @@ namespace transformer_engine::pytorch::nvfp4_1x64 { /// Apply 2D mode, SR, and optional 1x64 flag to a quantization config. inline void config_apply(QuantizationConfigWrapper& cfg, bool nvfp4_2d, bool stochastic_rounding, - bool use_1x64) { + bool use_1x64) { cfg.set_nvfp4_2d_quantization(nvfp4_2d); cfg.set_stochastic_rounding(stochastic_rounding); cfg.set_nvfp4_rowwise_1x64_local_encode(use_1x64); @@ -43,7 +44,8 @@ inline void require_ok_for_non_split(bool with_rht, bool /* columnwise */, bool NVTE_CHECK( !with_rht, "NVTE_NVFP4_ROWWISE_1X64_LOCAL_ENCODE=1 requires non-RHT (e.g. NVTE_NVFP4_DISABLE_RHT=1)."); - NVTE_CHECK(!sr, "NVTE_NVFP4_ROWWISE_1X64_LOCAL_ENCODE=1 is incompatible with stochastic rounding."); + NVTE_CHECK(!sr, + "NVTE_NVFP4_ROWWISE_1X64_LOCAL_ENCODE=1 is incompatible with stochastic rounding."); } /// Preconditions for \p split_quantize (non-RHT path). @@ -51,8 +53,8 @@ inline void require_ok_for_split(bool /* want_rowwise */, bool /* have_columnwis if (!local_encode_from_env()) { return; } - NVTE_CHECK( - !sr, "NVTE_NVFP4_ROWWISE_1X64_LOCAL_ENCODE in split_quantize is incompatible with SR."); + NVTE_CHECK(!sr, + "NVTE_NVFP4_ROWWISE_1X64_LOCAL_ENCODE in split_quantize is incompatible with SR."); } } // namespace transformer_engine::pytorch::nvfp4_1x64 diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 611c4cbbc9..4ba23fee94 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1779,9 +1779,9 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve // hadamard amax kernel will zero out pointer with ZeroAmaxKernel // nvte_compute_amax_with_config will zero out the pointer if needed if (use_1x64) { - amax_rowwise = at::empty({static_cast(flat_first_dim), - static_cast(flat_last_dim) / 64}, - bit32_tensor_opts); + amax_rowwise = at::empty( + {static_cast(flat_first_dim), static_cast(flat_last_dim) / 64}, + bit32_tensor_opts); } else { amax_rowwise = at::empty({1}, bit32_tensor_opts); } @@ -1800,9 +1800,9 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve // hadamard amax kernel will zero out pointer with ZeroAmaxKernel // nvte_compute_amax_with_config will zero out the pointer if needed if (use_1x64) { - amax_columnwise = at::empty({static_cast(flat_last_dim), - static_cast(flat_first_dim) / 64}, - bit32_tensor_opts); + amax_columnwise = at::empty( + {static_cast(flat_last_dim), static_cast(flat_first_dim) / 64}, + bit32_tensor_opts); } else { amax_columnwise = at::empty({1}, bit32_tensor_opts); } @@ -1879,8 +1879,7 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve out_cpp.set_rowwise_scale_inv(rowwise_scale_inv_tensor.data_ptr(), DType::kFloat8E4M3, rowwise_scale_inv_shape); const std::vector amax_row_shape = - use_1x64 ? std::vector{flat_first_dim, flat_last_dim / 64} - : std::vector{1}; + use_1x64 ? std::vector{flat_first_dim, flat_last_dim / 64} : std::vector{1}; out_cpp.set_amax(amax_rowwise.data_ptr(), DType::kFloat32, amax_row_shape); } if (columnwise_usage) { @@ -1893,8 +1892,7 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve out_cpp.set_columnwise_scale_inv(columnwise_scale_inv_tensor.data_ptr(), DType::kFloat8E4M3, columnwise_scale_inv_shape); const std::vector amax_col_shape = - use_1x64 ? std::vector{flat_last_dim, flat_first_dim / 64} - : std::vector{1}; + use_1x64 ? std::vector{flat_last_dim, flat_first_dim / 64} : std::vector{1}; out_cpp.set_columnwise_amax(amax_columnwise.data_ptr(), DType::kFloat32, amax_col_shape); } out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); @@ -2265,7 +2263,8 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou } nvfp4_1x64::config_apply(quant_config, this->with_2d_quantization, this->stochastic_rounding, nvfp4_1x64::local_encode_from_env()); - nvfp4_1x64::require_ok_for_non_split(this->with_rht, this->columnwise_usage, this->stochastic_rounding); + nvfp4_1x64::require_ok_for_non_split(this->with_rht, this->columnwise_usage, + this->stochastic_rounding); // We only need RHT for columnwise usage. // flat first dim and last dim for multi dimensional input diff --git a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_1x64.py b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_1x64.py index cd1ac40c63..1a554c62f0 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_1x64.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4_1x64.py @@ -135,9 +135,7 @@ def _quantize_2d( ) device = x.device - fp32_max = torch.tensor( - torch.finfo(torch.float32).max, device=device, dtype=torch.float32 - ) + fp32_max = torch.tensor(torch.finfo(torch.float32).max, device=device, dtype=torch.float32) fp4_max = torch.tensor(FLOAT4_E2M1_MAX, device=device, dtype=torch.float32) fp8_max = torch.tensor(FLOAT8_E4M3_MAX, device=device, dtype=torch.float32) @@ -197,7 +195,9 @@ def _quantize_2d( # SM10 -- we do NOT want that). Short-circuit to 0 to mirror the # kernel's ``s_dec_f == 0.f`` branch. zero_blk = decode_scale_back_fp32 == 0 - denom = torch.where(zero_blk, torch.ones_like(decode_scale_back_fp32), decode_scale_back_fp32) + denom = torch.where( + zero_blk, torch.ones_like(decode_scale_back_fp32), decode_scale_back_fp32 + ) encode_scale = S_enc_per_blk / denom encode_scale = torch.where(zero_blk, torch.zeros_like(encode_scale), encode_scale) encode_scale = torch.minimum(encode_scale, fp32_max) From 1f1b16cc289643854156c687220d823d59838ec1 Mon Sep 17 00:00:00 2001 From: Cael Ling Date: Wed, 29 Apr 2026 02:03:09 -0700 Subject: [PATCH 09/10] Add NVFP4 1x64 local-encode support to split_quantize Signed-off-by: Cael Ling --- .../nvfp4/test_nvfp4_1x64_split_quantize.py | 360 ++++++++++++++++++ .../pytorch/csrc/extensions/cast.cpp | 133 +++---- 2 files changed, 421 insertions(+), 72 deletions(-) create mode 100644 tests/pytorch/nvfp4/test_nvfp4_1x64_split_quantize.py diff --git a/tests/pytorch/nvfp4/test_nvfp4_1x64_split_quantize.py b/tests/pytorch/nvfp4/test_nvfp4_1x64_split_quantize.py new file mode 100644 index 0000000000..6ef744f309 --- /dev/null +++ b/tests/pytorch/nvfp4/test_nvfp4_1x64_split_quantize.py @@ -0,0 +1,360 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Bit-exact tests for NVFP4 1x64 local-encode through ``tex.split_quantize``. + +Together with ``test_nvfp4_1x64_quantize_exact.py`` (which covers single-tensor +quantize) these tests lock down the multi-chunk dispatch path that drives +``split_quantize_nvfp4_impl_helper`` (and its ``UNFUSED`` fallback) when the +1x64 local-encode flag is on. Two complementary axes are exercised: + +* **Axis A -- algorithm oracle.** + Compare the SUT (``tex.split_quantize``) against a pure-PyTorch oracle + (``NVFP4Quantizer1x64Ref``) byte-for-byte, applied per chunk. The oracle is + independent of any TransformerEngine CUDA kernel, so any deviation must + come from the kernel itself (rounding, scale derivation, packed FP4 + layout). This catches algorithmic regressions even when both C++ + dispatch paths are wired identically. + +* **Axis B -- wiring oracle.** + Compare the SUT against ``quantizers[i](chunk)`` running per chunk via + ``reference_group_quantize``. Both sides ultimately invoke the same + ``quantize_1x64_local_encode`` kernel, so their outputs must be + bit-identical. Any mismatch points at the split-quantize driver itself -- + buffer allocation (``BULK_NVFP4`` vs. ``UNFUSED``), per-chunk config + propagation, ``nvte_quantize_v2`` arguments, or the inner-dim alignment + fallback in the dispatcher. + +Both axes use ``atol=rtol=0``: the 1x64 path is deterministic (no +stochastic rounding, no RHT random sign mask), so any non-zero diff is a +real regression. +""" + +import pytest +import torch + +import transformer_engine.pytorch as te +import transformer_engine_torch as tex +from transformer_engine.pytorch import NVFP4Quantizer +from transformer_engine.pytorch.custom_recipes.quantization_nvfp4_1x64 import ( + NVFP4Quantizer1x64Ref, +) + +from nvfp4_utils import ( + assert_same_shape_and_dtype, + generate_split_sections, + get_nvfp4_scale_shape_no_padding, + reference_group_quantize, +) + + +recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) + + +# ----------------------------------------------------------------------------- +# Shape matrix +# +# 1x64 hierarchy requires both M and N to be multiples of 64. The split-quantize +# dispatcher additionally has two routing branches that we want to exercise: +# +# * N % 128 == 0 -> ``QuantizationMethod::FUSED_NVFP4`` keeps the call on the +# fused ``split_quantize_nvfp4_impl_helper`` path. +# * N % 128 != 0 -> dispatcher mirrors the ``BULK_NVFP4`` ``% 128`` fallback +# and downgrades to ``QuantizationMethod::UNFUSED`` (per-tensor +# ``NVFP4Quantizer::quantize_impl`` loop). The kernel itself is the same; +# only the driver changes. +# +# Both branches are kept here so that a regression in either lane fails CI. +# ----------------------------------------------------------------------------- +_SHAPES_FUSED_PATH = [ + (256, 1024), + (1024, 256), + (2048, 2048), +] +_SHAPES_FALLBACK_PATH = [ + (1024, 320), # N % 128 != 0, hits the dispatcher's UNFUSED fallback. +] +_SHAPES = _SHAPES_FUSED_PATH + _SHAPES_FALLBACK_PATH + + +def _unpack_fp4(x: torch.Tensor) -> torch.Tensor: + """Unpack two FP4 values per byte into one ``uint8`` per element.""" + repeated = x.repeat_interleave(2, dim=1) + repeated[:, 0::2] &= 0x0F + repeated[:, 1::2] >>= 4 + return repeated + + +def _make_quantizers(num: int, *, rowwise: bool, columnwise: bool): + """Build a list of ``NVFP4Quantizer`` configured for 1x64-compatible mode. + + 1x64 is mutually exclusive with RHT and 2D quantization; the constructor + flags here mirror that constraint so we never feed an invalid combination + into either dispatch path. + """ + return [ + NVFP4Quantizer( + fp4_dtype=tex.DType.kFloat4E2M1, + rowwise=rowwise, + columnwise=columnwise, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=False, + ) + for _ in range(num) + ] + + +def _make_input(M: int, N: int, dtype: torch.dtype) -> torch.Tensor: + torch.manual_seed(0) + torch.cuda.manual_seed(0) + return torch.randn((M, N), dtype=dtype, device="cuda") + + +# ----------------------------------------------------------------------------- +# Axis A: algorithm oracle. tex.split_quantize vs NVFP4Quantizer1x64Ref +# ----------------------------------------------------------------------------- +def _check_axis_a_against_oracle( + *, + M: int, + N: int, + dtype: torch.dtype, + split_sections: list[int], + rowwise: bool, + columnwise: bool, +) -> None: + x = _make_input(M, N, dtype) + chunks = torch.split(x, split_sections) + quantizers = _make_quantizers(len(split_sections), rowwise=rowwise, columnwise=columnwise) + + sut_outputs = tex.split_quantize(x, split_sections, quantizers) + + for i, (sut, chunk) in enumerate(zip(sut_outputs, chunks)): + if split_sections[i] == 0: + # Empty chunk: SUT just allocates zero-row buffers. The oracle's + # behaviour on an empty input is irrelevant -- we only need to + # know the kernel did not write past the chunk boundary. + if rowwise: + assert sut._rowwise_data.shape[0] == 0, f"chunk {i} rowwise data not empty" + assert sut._amax_rowwise.shape[0] == 0, f"chunk {i} rowwise amax not empty" + if columnwise: + # columnwise tensors are transposed: (N, M_chunk//2) etc. + assert sut._columnwise_data.shape[1] == 0, ( + f"chunk {i} columnwise data not empty" + ) + assert sut._amax_columnwise.shape[1] == 0, ( + f"chunk {i} columnwise amax not empty" + ) + continue + + ref = NVFP4Quantizer1x64Ref(rowwise=rowwise, columnwise=columnwise).quantize(chunk) + + if rowwise: + # Kernel pads qx/sx to alignment boundaries; only compare the + # un-padded prefix, exactly as test_nvfp4_1x64_quantize_exact.py + # does for the single-tensor path. + qx_sut = _unpack_fp4(sut._rowwise_data.view(dtype=torch.uint8)) + qx_ref = _unpack_fp4(ref.data.view(dtype=torch.uint8)) + sx_sut = sut._rowwise_scale_inv.view(dtype=torch.uint8) + sx_ref = ref.scale.view(dtype=torch.uint8) + torch.testing.assert_close( + qx_sut[: qx_ref.shape[0], : qx_ref.shape[1]], qx_ref, atol=0.0, rtol=0.0 + ) + torch.testing.assert_close( + sx_sut[: sx_ref.shape[0], : sx_ref.shape[1]], sx_ref, atol=0.0, rtol=0.0 + ) + torch.testing.assert_close( + sut._amax_rowwise, ref.window_amax_row, atol=0.0, rtol=0.0 + ) + + if columnwise: + qxt_sut = _unpack_fp4(sut._columnwise_data.view(dtype=torch.uint8)) + qxt_ref = _unpack_fp4(ref.columnwise_data.view(dtype=torch.uint8)) + sxt_sut = sut._columnwise_scale_inv.view(dtype=torch.uint8) + sxt_ref = ref.columnwise_scale.view(dtype=torch.uint8) + torch.testing.assert_close( + qxt_sut[: qxt_ref.shape[0], : qxt_ref.shape[1]], qxt_ref, atol=0.0, rtol=0.0 + ) + torch.testing.assert_close( + sxt_sut[: sxt_ref.shape[0], : sxt_ref.shape[1]], sxt_ref, atol=0.0, rtol=0.0 + ) + torch.testing.assert_close( + sut._amax_columnwise, ref.window_amax_col, atol=0.0, rtol=0.0 + ) + + +# ----------------------------------------------------------------------------- +# Axis B: wiring oracle. tex.split_quantize vs per-chunk quantizer(chunk) +# ----------------------------------------------------------------------------- +def _check_axis_b_against_per_tensor( + *, + M: int, + N: int, + dtype: torch.dtype, + split_sections: list[int], + rowwise: bool, + columnwise: bool, +) -> None: + x = _make_input(M, N, dtype) + chunks = torch.split(x, split_sections) + + # Build TWO independent quantizer lists -- one drives the SUT, one drives + # the per-tensor reference -- to keep the comparison strictly between the + # two C++ dispatch paths and not accidentally reuse any quantizer state. + sut_quantizers = _make_quantizers(len(split_sections), rowwise=rowwise, columnwise=columnwise) + ref_quantizers = _make_quantizers(len(split_sections), rowwise=rowwise, columnwise=columnwise) + + sut_outputs = tex.split_quantize(x, split_sections, sut_quantizers) + + # ``reference_group_quantize`` calls ``ref_quantizers[i](chunk)`` per + # chunk -- with the 1x64 env flag on, this still runs the same + # ``quantize_1x64_local_encode`` kernel, just driven through the + # single-tensor entry point. SUT and reference must agree bit-for-bit + # on every byte the kernel actually writes. + qx_ref, sx_ref, amax_row_ref, qxt_ref, sxt_ref, amax_col_ref = reference_group_quantize( + x, ref_quantizers, split_sections, rowwise, columnwise + ) + + # NVFP4 scale buffers are over-allocated (rounded up to the cuBLAS + # block-scaling-factor layout: 128 in the outer dim, 4 in the inner dim) + # so that swizzle can run in place. The kernel only writes the un-padded + # prefix; the padded tail is left in whatever state ``at::empty`` + # returned, which differs across allocations. Slice both sides down to + # the valid prefix before doing a bit-exact compare. Data buffers + # ``(M, N/2)`` and 1x64 amax buffers ``(M, N/64)`` / ``(N, M/64)`` are + # already exact-sized, so they can be compared whole. + if rowwise: + sut_qx = [out._rowwise_data.view(dtype=torch.uint8) for out in sut_outputs] + sut_sx = [out._rowwise_scale_inv for out in sut_outputs] + sut_amax = [out._amax_rowwise for out in sut_outputs] + for i in range(len(sut_outputs)): + if split_sections[i] == 0: + assert_same_shape_and_dtype(sut_qx[i], qx_ref[i]) + assert_same_shape_and_dtype(sut_sx[i], sx_ref[i]) + assert_same_shape_and_dtype(sut_amax[i], amax_row_ref[i]) + continue + torch.testing.assert_close(sut_qx[i], qx_ref[i], atol=0.0, rtol=0.0) + torch.testing.assert_close(sut_amax[i], amax_row_ref[i], atol=0.0, rtol=0.0) + valid = get_nvfp4_scale_shape_no_padding(chunks[i].shape, columnwise=False) + torch.testing.assert_close( + sut_sx[i][: valid[0], : valid[1]], + sx_ref[i][: valid[0], : valid[1]], + atol=0.0, + rtol=0.0, + ) + + if columnwise: + sut_qxt = [out._columnwise_data.view(dtype=torch.uint8) for out in sut_outputs] + sut_sxt = [out._columnwise_scale_inv for out in sut_outputs] + sut_amax_t = [out._amax_columnwise for out in sut_outputs] + for i in range(len(sut_outputs)): + if split_sections[i] == 0: + assert_same_shape_and_dtype(sut_qxt[i], qxt_ref[i]) + assert_same_shape_and_dtype(sut_sxt[i], sxt_ref[i]) + assert_same_shape_and_dtype(sut_amax_t[i], amax_col_ref[i]) + continue + torch.testing.assert_close(sut_qxt[i], qxt_ref[i], atol=0.0, rtol=0.0) + torch.testing.assert_close(sut_amax_t[i], amax_col_ref[i], atol=0.0, rtol=0.0) + valid = get_nvfp4_scale_shape_no_padding(chunks[i].shape, columnwise=True) + torch.testing.assert_close( + sut_sxt[i][: valid[0], : valid[1]], + sxt_ref[i][: valid[0], : valid[1]], + atol=0.0, + rtol=0.0, + ) + + +# ----------------------------------------------------------------------------- +# Pytest entry points +# ----------------------------------------------------------------------------- +@pytest.fixture +def _enable_1x64(monkeypatch): + """Enable 1x64 local-encode for both the SUT and the reference paths. + + The flag is read at every quantize call (it gates ``local_encode_from_env`` + in ``nvfp4_1x64.h``), so it must be set before any tensor flows through + either ``tex.split_quantize`` or ``quantizer(chunk)``. ``monkeypatch`` + scopes the change to a single test, so other tests in the same session + are unaffected. + """ + monkeypatch.setenv("NVTE_NVFP4_ROWWISE_1X64_LOCAL_ENCODE", "1") + monkeypatch.setenv("NVTE_NVFP4_DISABLE_RHT", "1") + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize("M, N", _SHAPES) +@pytest.mark.parametrize( + "edge_cases", + [ + "regular", + "zero_tokens_front", + "zero_tokens_middle", + "random_uneven_split", + ], +) +@pytest.mark.parametrize( + "rowwise, columnwise", + [(True, False), (False, True), (True, True)], + ids=["row", "col", "rowcol"], +) +@pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str) +def test_split_quantize_1x64_axis_a_oracle( + _enable_1x64, + x_dtype: torch.dtype, + M: int, + N: int, + edge_cases: str, + rowwise: bool, + columnwise: bool, +) -> None: + """Axis A: split-quantize output must equal the pure-PyTorch 1x64 oracle.""" + split_sections = generate_split_sections(M, N, edge_cases, least_multiple=64) + _check_axis_a_against_oracle( + M=M, + N=N, + dtype=x_dtype, + split_sections=split_sections, + rowwise=rowwise, + columnwise=columnwise, + ) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize("M, N", _SHAPES) +@pytest.mark.parametrize( + "edge_cases", + [ + "regular", + "zero_tokens_front", + "zero_tokens_middle", + "random_uneven_split", + ], +) +@pytest.mark.parametrize( + "rowwise, columnwise", + [(True, False), (False, True), (True, True)], + ids=["row", "col", "rowcol"], +) +@pytest.mark.parametrize("x_dtype", [torch.bfloat16], ids=str) +def test_split_quantize_1x64_axis_b_wiring( + _enable_1x64, + x_dtype: torch.dtype, + M: int, + N: int, + edge_cases: str, + rowwise: bool, + columnwise: bool, +) -> None: + """Axis B: split-quantize and per-tensor driver must produce identical bytes.""" + split_sections = generate_split_sections(M, N, edge_cases, least_multiple=64) + _check_axis_b_against_per_tensor( + M=M, + N=N, + dtype=x_dtype, + split_sections=split_sections, + rowwise=rowwise, + columnwise=columnwise, + ) diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index ee0cdb61dc..9f45fd7822 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -1185,39 +1185,6 @@ void split_quantize_nvfp4_impl_with_rht_helper(const TensorWrapper &input, } } -/// Non-RHT split: grouped amax over splits, then optional D2D copy rowwise amax -> columnwise amax. -static void split_nvfp4_non_rht_run_grouped_amax( - const TensorWrapper &input, std::vector &output_list, - std::vector &nvte_tensor_output_list, const std::vector &split_sections, - size_t num_tensors, bool copy_colwise_amax_from_rowwise, cudaStream_t stream) { - std::vector orig_amax_ptr_list; - orig_amax_ptr_list.reserve(num_tensors); - for (size_t i = 0; i < num_tensors; i++) { - auto rowwise_amax_ptr = output_list[i].get_amax().data_ptr; - orig_amax_ptr_list.push_back(rowwise_amax_ptr); - auto columnwise_amax_ptr = output_list[i].get_columnwise_amax().data_ptr; - void *amax_ptr = rowwise_amax_ptr != nullptr ? rowwise_amax_ptr : columnwise_amax_ptr; - NVTE_CHECK(amax_ptr != nullptr, "Could not find amax pointer"); - output_list[i].set_amax(amax_ptr, DType::kFloat32, std::vector{1}); - } - nvte_group_amax(input.data(), reinterpret_cast(nvte_tensor_output_list.data()), - split_sections.data(), num_tensors, stream); - for (size_t i = 0; i < num_tensors; i++) { - output_list[i].set_amax(orig_amax_ptr_list[i], DType::kFloat32, std::vector{1}); - } - if (copy_colwise_amax_from_rowwise) { - for (size_t i = 0; i < num_tensors; i++) { - auto rowwise_amax_ptr = output_list[i].get_amax().data_ptr; - auto colwise_amax_ptr = output_list[i].get_columnwise_amax().data_ptr; - if (rowwise_amax_ptr != nullptr && colwise_amax_ptr != nullptr && - rowwise_amax_ptr != colwise_amax_ptr) { - NVTE_CHECK_CUDA(cudaMemcpyAsync(colwise_amax_ptr, rowwise_amax_ptr, sizeof(float), - cudaMemcpyDeviceToDevice, stream)); - } - } - } -} - void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, const std::vector &input_list, std::vector &output_list, @@ -1232,31 +1199,34 @@ void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, nvte_tensor_output_list.push_back(output_list[i].data()); } + // In this case without RHT, the rowwise and colwise quantization are fused + // we don't need separate rng states for rowwise and colwise + bool need_separate_rng_states = false; + // Objects for TE C API std::vector quant_config_list; for (size_t i = 0; i < num_tensors; ++i) { quant_config_list.emplace_back(QuantizationConfigWrapper()); } - bool with_bulk_generate_rng_states = true; - bool need_separate_rng_states = false; + // Per-tensor quantize: each tensor gets its own kernel launch, so RNG states + // are advanced once per tensor on the host. + bool with_bulk_generate_rng_states = false; + bool need_stochastic_rounding = quantizer.stochastic_rounding; - std::vector quant_config_list_colwise; - for (size_t i = 0; i < num_tensors; ++i) { - quant_config_list_colwise.emplace_back(QuantizationConfigWrapper()); - } + + // place holder for colwise rng states, which are not needed in this case + std::vector dummy_quant_config_list_colwise; + auto stochastic_rng_state_resources = setup_stochastic_rounding_rng_states_helper( num_tensors, need_stochastic_rounding, with_bulk_generate_rng_states, - need_separate_rng_states, quant_config_list, quant_config_list_colwise); + need_separate_rng_states, quant_config_list, + dummy_quant_config_list_colwise); // colwise rng states are not needed in this case - const auto use_fast_math = transformer_engine::getenv("NVTE_USE_FAST_MATH"); - if (use_fast_math) { - for (auto &config : quant_config_list) { - config.set_use_fast_math(true); - } - } - - // 1x64: per-tensor nvte_quantize_v2 (see quantize.cuh), not the grouped amax+kernel. + // Optional NVFP4 1x64 local-encode: per-window amax is computed inside the + // kernel (see quantize.cuh), so the global per-tensor amax pre-pass is + // skipped. Otherwise the per-tensor nvte_quantize_v2 path remains identical + // to the non-1x64 baseline. const bool use_rowwise_1x64 = nvfp4_1x64::local_encode_from_env(); for (size_t i = 0; i < num_tensors; ++i) { nvfp4_1x64::config_apply(quant_config_list[i], quantizer.with_2d_quantization, @@ -1266,32 +1236,35 @@ void split_quantize_nvfp4_impl_helper(const TensorWrapper &input, quantizer.stochastic_rounding); if (!use_rowwise_1x64) { - split_nvfp4_non_rht_run_grouped_amax( - input, output_list, nvte_tensor_output_list, split_sections, num_tensors, - quantizer.rowwise_usage && quantizer.columnwise_usage, stream); - } - - if (use_rowwise_1x64 && quantizer.rowwise_usage) { + // We need: + // 1. Rowwise amax = amax for input + // 2. Columnwise amax = amax for input too + // Columnwise amax will be filled with a fused D2D copy from rowwise amax + // Note that the multi compute amax API expects rowwise amax pointer to be not null + // So we need to set the pointer accordingly to make colwise-only quantization work + std::vector orig_amax_ptr_list; for (size_t i = 0; i < num_tensors; i++) { - if (input_list[i].numel() == 0) { - continue; - } - nvte_quantize_v2(input_list[i].data(), output_list[i].data(), quant_config_list[i], stream); + auto rowwise_amax_ptr = output_list[i].get_amax().data_ptr; + orig_amax_ptr_list.push_back(rowwise_amax_ptr); + auto columnwise_amax_ptr = output_list[i].get_columnwise_amax().data_ptr; + void *amax_ptr = rowwise_amax_ptr != nullptr ? rowwise_amax_ptr : columnwise_amax_ptr; + NVTE_CHECK(amax_ptr != nullptr, "Could not find amax pointer"); + output_list[i].set_amax(amax_ptr, DType::kFloat32, std::vector{1}); } - } else if (quantizer.rowwise_usage) { - // Grouped rowwise (+ columnwise if output tensors carry colwise buffers) - // in a single kernel launch. When the output has columnwise data, the kernel - // template parameter RETURN_TRANSPOSE=true enables the colwise write-back path. - nvte_group_nvfp4_quantize_with_amax(input.data(), nvte_tensor_output_list.data(), - split_sections.data(), num_tensors, quant_config_list[0], - stream); - } else if (quantizer.columnwise_usage) { - // Colwise-only: the grouped kernel requires a contiguous rowwise output - // buffer for TMA, so fall back to per-tensor quantization. + nvte_group_amax(input.data(), reinterpret_cast(nvte_tensor_output_list.data()), + split_sections.data(), num_tensors, stream); for (size_t i = 0; i < num_tensors; i++) { - if (input_list[i].numel() == 0) continue; - nvte_quantize_v2(input_list[i].data(), output_list[i].data(), quant_config_list[i], stream); + output_list[i].set_amax(orig_amax_ptr_list[i], DType::kFloat32, std::vector{1}); + } + } + + // Quantize tensors individually + for (size_t i = 0; i < num_tensors; i++) { + // skip this round if input is empty + if (input_list[i].numel() == 0) { + continue; } + nvte_quantize_v2(input_list[i].data(), output_list[i].data(), quant_config_list[i], stream); } } @@ -1431,8 +1404,24 @@ std::vector split_quantize(const at::Tensor &tensor, [](const py::handle &quantizer) -> bool { return detail::IsNVFP4Quantizers(quantizer.ptr()); })) { - allocation_method = AllocationMethod::BULK_NVFP4; - quantization_method = QuantizationMethod::FUSED_NVFP4; + // bulk_allocate_nvfp4_tensors only knows the stock (1,) amax layout; the + // 1x64 kernel needs a per-window amax buffer of shape (M_chunk, N/64). + // When 1x64 is requested, fall back to per-tensor allocation + // (NVFP4Quantizer::create_tensor is 1x64-aware) but keep the fused + // quantize path so split_quantize_nvfp4_impl_helper still drives the + // kernel. Mirror the BULK_NVFP4 case's `% 128 != 0` fallback so + // non-aligned inner dims still go through the unfused per-tensor path + // (split_quantize_nvfp4_impl asserts inner dim is a multiple of 128). + if (nvfp4_1x64::local_encode_from_env()) { + if (!input_shape.empty() && input_shape.back() % 128 != 0) { + quantization_method = QuantizationMethod::UNFUSED; + } else { + quantization_method = QuantizationMethod::FUSED_NVFP4; + } + } else { + allocation_method = AllocationMethod::BULK_NVFP4; + quantization_method = QuantizationMethod::FUSED_NVFP4; + } } } From 68fdbd61b0b5d69d0362f42673cb157967fd12e8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 29 Apr 2026 09:11:49 +0000 Subject: [PATCH 10/10] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../pytorch/nvfp4/test_nvfp4_1x64_split_quantize.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/tests/pytorch/nvfp4/test_nvfp4_1x64_split_quantize.py b/tests/pytorch/nvfp4/test_nvfp4_1x64_split_quantize.py index 6ef744f309..e5bdc4b1dc 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_1x64_split_quantize.py +++ b/tests/pytorch/nvfp4/test_nvfp4_1x64_split_quantize.py @@ -142,12 +142,8 @@ def _check_axis_a_against_oracle( assert sut._amax_rowwise.shape[0] == 0, f"chunk {i} rowwise amax not empty" if columnwise: # columnwise tensors are transposed: (N, M_chunk//2) etc. - assert sut._columnwise_data.shape[1] == 0, ( - f"chunk {i} columnwise data not empty" - ) - assert sut._amax_columnwise.shape[1] == 0, ( - f"chunk {i} columnwise amax not empty" - ) + assert sut._columnwise_data.shape[1] == 0, f"chunk {i} columnwise data not empty" + assert sut._amax_columnwise.shape[1] == 0, f"chunk {i} columnwise amax not empty" continue ref = NVFP4Quantizer1x64Ref(rowwise=rowwise, columnwise=columnwise).quantize(chunk) @@ -166,9 +162,7 @@ def _check_axis_a_against_oracle( torch.testing.assert_close( sx_sut[: sx_ref.shape[0], : sx_ref.shape[1]], sx_ref, atol=0.0, rtol=0.0 ) - torch.testing.assert_close( - sut._amax_rowwise, ref.window_amax_row, atol=0.0, rtol=0.0 - ) + torch.testing.assert_close(sut._amax_rowwise, ref.window_amax_row, atol=0.0, rtol=0.0) if columnwise: qxt_sut = _unpack_fp4(sut._columnwise_data.view(dtype=torch.uint8))