diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 1c968e276d..2a13a01dbb 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1260,6 +1260,18 @@ std::pair NVFP4Quantizer::create_unquantized_tensor_w return {std::move(out_cpp), std::move(out_py)}; } +std::vector compressShapeTo2D(const std::vector& data) { + // If 2 or fewer elements, return as-is + if (data.size() <= 2) { + return data; + } + // Multiply all elements except the last + size_t product = std::accumulate(data.begin(), data.end() - 1, static_cast(1), + std::multiplies()); + // Return new vector of size 2: {product, last} + return std::vector{product, data.back()}; +} + std::pair NVFP4Quantizer::convert_and_update_tensor( py::object tensor) const { NVTE_CHECK(detail::IsNVFP4Tensor(tensor.ptr()), "NVFP4Quantizer must output to IsNVFP4Tensor."); @@ -1289,8 +1301,10 @@ std::pair NVFP4Quantizer::convert_and_update_tensor( shape = convert_shape_back_from_fp4(getTensorShape(*columnwise_data), true); if (rowwise_data) { auto expected_shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false); - NVTE_CHECK(shape == expected_shape, "NVFP4 row-wise data (shape=", expected_shape, + auto expected_shape_2d = compressShapeTo2D(expected_shape); + NVTE_CHECK(shape == expected_shape_2d, "NVFP4 row-wise data (2D shape=", expected_shape_2d, ") and column-wise data (shape=", shape, ") do not match"); + shape = expected_shape; } } else { // Already checked columnwise_data_tensor == true shape = convert_shape_back_from_fp4(getTensorShape(*rowwise_data), false);