From a35de590341a6e181972ffed27e4395ad05c540d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E4=B9=99=E5=88=92?= Date: Tue, 10 Feb 2026 16:57:38 +0800 Subject: [PATCH 1/2] fix nvfp4 convert_and_update_tensor shape check MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: 乙划 --- transformer_engine/pytorch/csrc/quantizer.cpp | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 1c968e276d..6a75c5559b 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1260,6 +1260,22 @@ std::pair NVFP4Quantizer::create_unquantized_tensor_w return {std::move(out_cpp), std::move(out_py)}; } +std::vector compressShapeTo2D(const std::vector& data) { + // If 2 or fewer elements, return as-is + if (data.size() <= 2) { + return data; + } + // Multiply all elements except the last + size_t product = std::accumulate( + data.begin(), + data.end() - 1, + static_cast(1), + std::multiplies() + ); + // Return new vector of size 2: {product, last} + return std::vector{ product, data.back() }; +} + std::pair NVFP4Quantizer::convert_and_update_tensor( py::object tensor) const { NVTE_CHECK(detail::IsNVFP4Tensor(tensor.ptr()), "NVFP4Quantizer must output to IsNVFP4Tensor."); @@ -1289,8 +1305,10 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( shape = convert_shape_back_from_fp4(getTensorShape(*columnwise_data), true); if (rowwise_data) { auto expected_shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false); - NVTE_CHECK(shape == expected_shape, "NVFP4 row-wise data (shape=", expected_shape, + auto expected_shape_2d = compressShapeTo2D(expected_shape); + NVTE_CHECK(shape == expected_shape_2d, "NVFP4 row-wise data (2D shape=", expected_shape_2d, ") and column-wise data (shape=", shape, ") do not match"); + shape = expected_shape; } } else { // Already checked columnwise_data_tensor == true shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false); From c7eaa0571127800066b03b2660f333e8144ed216 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 10 Feb 2026 09:07:39 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/pytorch/csrc/quantizer.cpp | 22 ++++++++----------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 6a75c5559b..2a13a01dbb 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1261,19 +1261,15 @@ std::pair NVFP4Quantizer::create_unquantized_tensor_w } std::vector compressShapeTo2D(const std::vector& data) { - // If 2 or fewer elements, return as-is - if (data.size() <= 2) { - return data; - } - // Multiply all elements except the last - size_t product = std::accumulate( - data.begin(), - data.end() - 1, - static_cast(1), - std::multiplies() - ); - // Return new vector of size 2: {product, last} - return std::vector{ product, data.back() }; + // If 2 or fewer elements, return as-is + if (data.size() <= 2) { + return data; + } + // Multiply all elements except the last + size_t product = std::accumulate(data.begin(), data.end() - 1, static_cast(1), + std::multiplies()); + // Return new vector of size 2: {product, last} + return std::vector{product, data.back()}; } std::pair NVFP4Quantizer::convert_and_update_tensor(