diff --git a/docs/envvars.rst b/docs/envvars.rst index 1e040b4c3e..665f7912ab 100644 --- a/docs/envvars.rst +++ b/docs/envvars.rst @@ -281,6 +281,12 @@ Kernel Configuration :Default: ``0`` :Description: Emit a warning when falling back from CUTLASS to cuBLAS for grouped GEMM operations. +.. envvar:: NVTE_NVFP4_ROW_SCALED_ACTIVATION + + :Type: ``int`` (0 or 1) + :Default: ``0`` + :Description: Enable row-scaled NVFP4 tensors for forward activation quantizers in the ``NVFP4BlockScaling`` recipe. When set to ``1`` (or when ``NVFP4BlockScaling(row_scaled_activation=True)`` is used), rowwise ``amax`` metadata is stored as one FP32 value per tensor row instead of a single scalar. + Torch Compilation and Fusion ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu index 15d7c695c9..1f37520bc7 100644 --- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu +++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu @@ -114,16 +114,14 @@ void quantize_nvfp4_1d(float (*OP)(const float), block_amax = std::max(block_amax, std::abs(elt)); } - // 2. Compute E4M3 scaling factor - // Compute per-block encoding/decoding scaling factor - const float S_dec_b = block_amax / 6.0f; - - // Scale & Store per-block decoding scaling factor - const fp8e4m3 S_dec_b_fp8 = static_cast(S_dec_b * S_enc); + // Compute and store the per-block FP8 decode scale + const float S_dec_b = block_amax * (S_enc * (1.0f / 6.0f)); + const fp8e4m3 S_dec_b_fp8 = static_cast(fminf(S_dec_b, Numeric_Traits::maxNorm)); const float S_dec_b_fp32 = static_cast(S_dec_b_fp8); // Compute "correct" per-block encoding scaling factor - const float S_enc_b_fp8 = S_dec_b_fp32 == 0.f ? 0.f : S_enc / S_dec_b_fp32; + const float S_enc_b_fp8 = S_dec_b_fp32 == 0.f ? 
0.f : + fminf(1.0f / (S_dec_b_fp32 * (1.0f / S_enc)), Numeric_Traits::maxNorm); const size_t scale_idx = i * scales_stride + block_X; scales[scale_idx] = S_dec_b_fp8; @@ -317,11 +315,31 @@ void compute_ref(float (*OP)(const float), const size_t scales_stride, const size_t scales_stride_t, const bool use_fast_math, - const bool use_2d_quantization = false) + const bool use_2d_quantization = false, + std::vector *rowwise_amax = nullptr) { std::vector input_t = create_transpose(input, rows, cols); - if (use_2d_quantization) { + if (rowwise_amax != nullptr) { + rowwise_amax->resize(rows, 0.0f); + for (size_t row = 0; row < rows; ++row) { + float row_amax = 0.0f; + for (size_t col = 0; col < cols; ++col) { + row_amax = fmaxf(row_amax, fabsf(static_cast(input[row * cols + col]))); + } + (*rowwise_amax)[row] = row_amax; + quantize_nvfp4(OP, + input + row * cols, + output + row * (cols / 2), + scales + row * scales_stride, + 1, + cols, + scales_stride, + row_amax, + use_fast_math, + use_2d_quantization); + } + } else if (use_2d_quantization) { // Step 1: Compute mathematical 8×8 scaling factors std::vector> math_scales; compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales, use_fast_math); @@ -504,13 +522,12 @@ void print_detailed_tensor_comparison(const std::string& name, void compareResults_nvfp4(const Tensor &test, const void *ref, const void *ref_t, const int rows, const int cols, - double atol = 1e-5, double rtol = 1e-8, bool if_on_gpus = true, bool dump_data = false) { + double atol = 1e-5, double rtol = 1e-8, bool if_on_gpus = true, + bool dump_data = false, bool compare_columnwise = true) { if (if_on_gpus) test.to_cpu(); const fp4e2m1 *test_data = test.rowwise_cpu_dptr(); - const fp4e2m1 *test_data_t = test.columnwise_cpu_dptr(); const fp4e2m1 *ref_data = reinterpret_cast(ref); - const fp4e2m1 *ref_data_t = reinterpret_cast(ref_t); // Print detailed element-by-element comparison // print_detailed_tensor_comparison("output", test_data, ref_data, rows, cols); @@ -519,17 +536,33 @@ void compareResults_nvfp4(const Tensor &test, // Optionally dump tensor data to files for detailed analysis if (dump_data) { dump_nvfp4_tensor_data("output", test_data, ref_data, rows, cols); - dump_nvfp4_tensor_data("output_t", test_data_t, ref_data_t, cols, rows); } compare_nvfp4_tensors("output", test_data, ref_data, rows, cols, atol, rtol); - compare_nvfp4_tensors("output_t", test_data_t, ref_data_t, cols, rows, atol, rtol); + if (compare_columnwise) { + const fp4e2m1 *test_data_t = test.columnwise_cpu_dptr(); + const fp4e2m1 *ref_data_t = reinterpret_cast(ref_t); + if (dump_data) { + dump_nvfp4_tensor_data("output_t", test_data_t, ref_data_t, cols, rows); + } + compare_nvfp4_tensors("output_t", test_data_t, ref_data_t, cols, rows, atol, rtol); + } +} + +void compare_rowwise_amax(const Tensor &output, const std::vector &ref_amax) { + const std::vector test_amax_data = output.tensor_amax_values(); + ASSERT_EQ(test_amax_data.size(), ref_amax.size()); + for (size_t row = 0; row < ref_amax.size(); ++row) { + ASSERT_EQ(test_amax_data[row], ref_amax[row]) + << "Row-scaled amax mismatch at row " << row; + } } template void performTest(float (*OP)(const float), const std::vector& shape, - const bool use_fast_math) { + const bool use_fast_math, + const bool row_scaled_nvfp4 = false) { using namespace test; DType itype = TypeInfo::dtype; @@ -556,7 +589,7 @@ void performTest(float (*OP)(const float), const size_t scales_stride_t = blocks_X_t; Tensor input("input", shape, itype); - Tensor 
output("output", shape, otype, true, true, NVTE_NVFP4_1D_SCALING); + Tensor output("output", shape, otype, true, !row_scaled_nvfp4, NVTE_NVFP4_1D_SCALING); std::unique_ptr ref_output = std::make_unique(rows * (cols / 2)); std::unique_ptr ref_output_t = std::make_unique(cols * (rows / 2)); @@ -567,26 +600,44 @@ void performTest(float (*OP)(const float), // Golden value of amax chosen to make the 2nd-stage scaling mantissa zero and avoid rounding issues const float amax = 448.0f * 6.0f * 8.0f; - - // Set 2nd stage NVFP4 scaling factor - output.set_tensor_amax(amax); - output.set_tensor_amax_columnwise(amax); - + std::vector ref_rowwise_amax; bool use_2d_quantization = false; + if (row_scaled_nvfp4) { + output.set_tensor_amax_shape({rows}); + output.set_row_scaled_nvfp4(true); + compute_ref(OP, + input.rowwise_cpu_dptr(), + ref_output.get(), + ref_output_t.get(), + ref_scales.get(), + ref_scales_t.get(), + 0.0f, + rows, + cols, + scales_stride, + scales_stride_t, + use_fast_math, + use_2d_quantization, + &ref_rowwise_amax); + } else { + // Set 2nd stage NVFP4 scaling factor + output.set_tensor_amax(amax); + output.set_tensor_amax_columnwise(amax); + compute_ref(OP, + input.rowwise_cpu_dptr(), + ref_output.get(), + ref_output_t.get(), + ref_scales.get(), + ref_scales_t.get(), + amax, + rows, + cols, + scales_stride, + scales_stride_t, + use_fast_math, + use_2d_quantization); + } - compute_ref(OP, - input.rowwise_cpu_dptr(), - ref_output.get(), - ref_output_t.get(), - ref_scales.get(), - ref_scales_t.get(), - amax, - rows, - cols, - scales_stride, - scales_stride_t, - use_fast_math, - use_2d_quantization); // Initialize stochastic rounding Tensor rng_state("rng_state", std::vector{2}, DType::kInt64); rng_state.rowwise_cpu_dptr()[0] = 123; // rng_seed @@ -629,12 +680,8 @@ void performTest(float (*OP)(const float), const double rtol = 1.0E-6; // Set dump_data=true to enable dumping tensor data to files for analysis - compareResults_nvfp4(output, ref_output.get(), ref_output_t.get(), rows, cols, atol, rtol, true, false); - - const fp8e4m3* kernel_scales = output.rowwise_cpu_scale_inv_ptr(); - const fp8e4m3* ref_scales_ptr = ref_scales.get(); - const fp8e4m3* kernel_scales_t = output.columnwise_cpu_scale_inv_ptr(); - const fp8e4m3* ref_scales_t_ptr = ref_scales_t.get(); + compareResults_nvfp4(output, ref_output.get(), ref_output_t.get(), rows, cols, atol, rtol, true, + false, !row_scaled_nvfp4); size_t scale_mismatches_num = 0; compare_scaling_factors("scales", output.rowwise_cpu_scale_inv_ptr(), @@ -642,10 +689,16 @@ void performTest(float (*OP)(const float), unpadded_blocks_Y, unpadded_blocks_X, scales_stride, scale_mismatches_num); - compare_scaling_factors("scales_t", output.columnwise_cpu_scale_inv_ptr(), - ref_scales_t.get(), - unpadded_blocks_Y_t, unpadded_blocks_X_t, scales_stride_t, - scale_mismatches_num); + if (!row_scaled_nvfp4) { + compare_scaling_factors("scales_t", output.columnwise_cpu_scale_inv_ptr(), + ref_scales_t.get(), + unpadded_blocks_Y_t, unpadded_blocks_X_t, scales_stride_t, + scale_mismatches_num); + } + + if (row_scaled_nvfp4) { + compare_rowwise_amax(output, ref_rowwise_amax); + } } std::vector> tensor_dims = { @@ -678,6 +731,7 @@ class FusedCastTransposeNVFP4TestSuite : public ::testing::TestWithParam , transformer_engine::DType, + bool, bool>> {}; TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { @@ -693,6 +747,7 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { const auto tensor_dims = std::get<1>(GetParam()); const DType 
input_type = std::get<2>(GetParam()); const bool use_fast_math = std::get<3>(GetParam()); + const bool row_scaled_nvfp4 = std::get<4>(GetParam()); // Skip tests if the input tensor is 1D if (tensor_dims.size() < 2) { @@ -710,7 +765,7 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) { } TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, - performTest(OP, tensor_dims, use_fast_math); + performTest(OP, tensor_dims, use_fast_math, row_scaled_nvfp4); ); } @@ -733,6 +788,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::ValuesIn(Activation_types), ::testing::ValuesIn(tensor_dims), ::testing::Values(DType::kBFloat16), + ::testing::Values(false), ::testing::Values(false)), [](const testing::TestParamInfo& info) { std::string name = to_string(std::get<0>(info.param)); @@ -746,3 +802,28 @@ INSTANTIATE_TEST_SUITE_P( } return name; }); + +INSTANTIATE_TEST_SUITE_P( + OperatorTestRowScaled, + FusedCastTransposeNVFP4TestSuite, + ::testing::Combine( + ::testing::Values(ActivationType::Identity), + ::testing::Values(tensor_dims[4], tensor_dims[9], tensor_dims[12]), + ::testing::Values(DType::kBFloat16, DType::kFloat32), + ::testing::Values(false), + ::testing::Values(true)), + [](const testing::TestParamInfo& info) { + std::string name = to_string(std::get<0>(info.param)); + const auto& shape = std::get<1>(info.param); + for (const auto& s: shape) { + name += "X" + std::to_string(s); + } + name += "X" + test::typeName(std::get<2>(info.param)); + if (std::get<3>(info.param)) { + name += "X_FAST_SCALING"; + } + if (std::get<4>(info.param)) { + name += "XROW_SCALED"; + } + return name; + }); diff --git a/tests/cpp/operator/test_dequantize_nvfp4.cu b/tests/cpp/operator/test_dequantize_nvfp4.cu index 96e85cb5ed..ec405b1d90 100644 --- a/tests/cpp/operator/test_dequantize_nvfp4.cu +++ b/tests/cpp/operator/test_dequantize_nvfp4.cu @@ -42,7 +42,7 @@ float2 cvt_fp4x2_to_float2(fp4e2m1x2 fp4_pair) { template void compute_ref_dequantize_nvfp4(const uint8_t *packed_data, const fp8e4m3 *scales, - float amax, + const std::vector &amax, OType *output, size_t rows, size_t cols, @@ -55,7 +55,8 @@ void compute_ref_dequantize_nvfp4(const uint8_t *packed_data, for (size_t row = 0; row < rows; ++row) { for (size_t block = 0; block < Mread; ++block) { const fp8e4m3 scale = scales[row * scale_stride + block]; - const float final_scale = static_cast(scale) * amax * factor_inv; + const float final_scale = + static_cast(scale) * (amax.size() == 1 ? amax[0] : amax[row]) * factor_inv; for (size_t pair_idx = 0; pair_idx < bytes_per_block; ++pair_idx) { const size_t byte_idx = @@ -88,7 +89,8 @@ float compute_amax(const test::Tensor &t, size_t rows, size_t cols) { // Quantize a high-precision input to NVFP4, then dequantize and compare // against a CPU reference computed from the quantized data. 
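The reference change in `compute_ref_dequantize_nvfp4` boils down to selecting the second-stage amax per row instead of per tensor before applying each block's E4M3 decode scale. A minimal torch sketch of that selection, assuming the usual NVFP4 constants (6.0 for the E2M1 max, 448.0 for the E4M3 max; the exact `factor_inv` value is an inference from the standard two-stage scheme, not shown in this hunk) and inputs already decoded to float:

```python
import torch

E2M1_MAX, E4M3_MAX = 6.0, 448.0  # assumed stage-1/stage-2 maxima

def dequantize_ref(vals, block_scales, amax):
    # vals: (rows, cols) FP4 values decoded to float
    # block_scales: (rows, cols // 16) E4M3 decode scales, viewed as float
    # amax: 1 element (per-tensor) or shape (rows,) (row-scaled NVFP4)
    rows, _ = vals.shape
    factor_inv = 1.0 / (E2M1_MAX * E4M3_MAX)
    # Mirror of `amax.size() == 1 ? amax[0] : amax[row]` in the CPU reference.
    row_amax = amax.expand(rows) if amax.numel() == 1 else amax
    final_scale = block_scales * (row_amax[:, None] * factor_inv)
    return vals * final_scale.repeat_interleave(16, dim=1)
```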
template -void performTest_dequantize_nvfp4(const size_t rows, const size_t cols) { +void performTest_dequantize_nvfp4(const size_t rows, const size_t cols, + const bool row_scaled_nvfp4) { using namespace test; DType otype = TypeInfo::dtype; @@ -97,7 +99,10 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols) { Tensor quantized("quantized", std::vector{rows, cols}, DType::kFloat4E2M1, true, false, NVTE_NVFP4_1D_SCALING); - if (rows > 0 && cols > 0) { + if (row_scaled_nvfp4) { + quantized.set_tensor_amax_shape({rows}); + quantized.set_row_scaled_nvfp4(true); + } else if (rows > 0 && cols > 0) { quantized.set_tensor_amax(compute_amax(input, rows, cols)); } else { quantized.set_tensor_amax(0.0f); @@ -120,7 +125,7 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols) { const uint8_t *fp4_data = reinterpret_cast(quantized.rowwise_cpu_dptr()); const fp8e4m3 *scales = quantized.rowwise_cpu_scale_inv_ptr(); - const float amax_val = quantized.amax(); + const std::vector amax_val = quantized.tensor_amax_values(); const NVTEShape scale_shape = quantized.rowwise_scale_inv_shape(); const size_t scale_stride = scale_shape.data[scale_shape.ndim - 1]; @@ -137,7 +142,8 @@ void performTest_dequantize_nvfp4(const size_t rows, const size_t cols) { // Dequantize NVFP4 with GEMM-swizzled scales and compare against compact path. template -void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols) { +void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols, + const bool row_scaled_nvfp4) { using namespace test; DType otype = TypeInfo::dtype; @@ -146,7 +152,10 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols) Tensor quantized_compact("quantized_compact", std::vector{rows, cols}, DType::kFloat4E2M1, true, false, NVTE_NVFP4_1D_SCALING); - if (rows > 0 && cols > 0) { + if (row_scaled_nvfp4) { + quantized_compact.set_tensor_amax_shape({rows}); + quantized_compact.set_row_scaled_nvfp4(true); + } else if (rows > 0 && cols > 0) { quantized_compact.set_tensor_amax(compute_amax(input, rows, cols)); } else { quantized_compact.set_tensor_amax(0.0f); @@ -157,7 +166,7 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols) cudaDeviceSynchronize(); } - // Dequantize with compact scales → reference output + // Dequantize with compact scales to get the reference output. Tensor output_compact("output_compact", std::vector{rows, cols}, otype, true, false); nvte_dequantize(quantized_compact.data(), output_compact.data(), 0); cudaDeviceSynchronize(); @@ -165,13 +174,22 @@ void performTest_dequantize_nvfp4_swizzled(const size_t rows, const size_t cols) // Create tensor with same FP4 data but swizzled scales Tensor quantized_swizzled("quantized_swizzled", std::vector{rows, cols}, DType::kFloat4E2M1, true, false, NVTE_NVFP4_1D_SCALING); - quantized_swizzled.set_tensor_amax(0.0f); + if (row_scaled_nvfp4) { + quantized_swizzled.set_tensor_amax_shape({rows}); + quantized_swizzled.set_row_scaled_nvfp4(true); + } else { + quantized_swizzled.set_tensor_amax(0.0f); + } quantized_swizzled.set_with_gemm_swizzled_scales(true); // Copy amax and scale from compact to swizzled before FP4 data, // since from_cpu() uploads all CPU buffers (including zero-init data). 
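The metadata change these dequantize tests exercise is only the amax shape: `set_tensor_amax_shape({rows})` swaps the single FP32 scalar for one FP32 value per tensor row, matching the `NVTE_NVFP4_ROW_SCALED_ACTIVATION` description above. A torch sketch of the two layouts (shapes illustrative):

```python
import torch

x = torch.randn(64, 256)
per_tensor_amax = x.abs().amax()        # scalar: classic NVFP4 second stage
row_scaled_amax = x.abs().amax(dim=1)   # shape (64,): one FP32 amax per row
```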
quantized_compact.to_cpu(); - quantized_swizzled.set_tensor_amax(quantized_compact.amax()); + if (row_scaled_nvfp4) { + quantized_swizzled.copy_tensor_amax_from(quantized_compact); + } else { + quantized_swizzled.set_tensor_amax(quantized_compact.amax()); + } // Copy FP4 data after from_cpu() to avoid being overwritten const size_t data_bytes = rows * cols / 2; @@ -227,7 +245,8 @@ std::vector> nvfp4_tensor_dims = { class DequantizeNVFP4TestSuite : public ::testing::TestWithParam , - transformer_engine::DType>> {}; + transformer_engine::DType, + bool>> {}; TEST_P(DequantizeNVFP4TestSuite, TestDequantizeNVFP4) { @@ -237,10 +256,11 @@ TEST_P(DequantizeNVFP4TestSuite, TestDequantizeNVFP4) const auto tensor_size = std::get<0>(GetParam()); const DType output_type = std::get<1>(GetParam()); + const bool row_scaled_nvfp4 = std::get<2>(GetParam()); TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(output_type, OutputType, performTest_dequantize_nvfp4( - tensor_size.first, tensor_size.second); + tensor_size.first, tensor_size.second, row_scaled_nvfp4); ); } @@ -249,19 +269,22 @@ INSTANTIATE_TEST_SUITE_P( DequantizeNVFP4TestSuite, ::testing::Combine( ::testing::ValuesIn(nvfp4_tensor_dims), - ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16)), + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Bool()), [](const testing::TestParamInfo& info) { std::string name = std::to_string(std::get<0>(info.param).first) + "X" + std::to_string(std::get<0>(info.param).second) + "X" + - test::typeName(std::get<1>(info.param)); + test::typeName(std::get<1>(info.param)) + "X" + + (std::get<2>(info.param) ? "RowScaled" : "PerTensor"); return name; } ); class DequantizeNVFP4SwizzledTestSuite : public ::testing::TestWithParam , - transformer_engine::DType>> {}; + transformer_engine::DType, + bool>> {}; TEST_P(DequantizeNVFP4SwizzledTestSuite, TestDequantizeNVFP4Swizzled) { @@ -271,10 +294,11 @@ TEST_P(DequantizeNVFP4SwizzledTestSuite, TestDequantizeNVFP4Swizzled) const auto tensor_size = std::get<0>(GetParam()); const DType output_type = std::get<1>(GetParam()); + const bool row_scaled_nvfp4 = std::get<2>(GetParam()); TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(output_type, OutputType, performTest_dequantize_nvfp4_swizzled( - tensor_size.first, tensor_size.second); + tensor_size.first, tensor_size.second, row_scaled_nvfp4); ); } @@ -283,12 +307,14 @@ INSTANTIATE_TEST_SUITE_P( DequantizeNVFP4SwizzledTestSuite, ::testing::Combine( ::testing::ValuesIn(nvfp4_tensor_dims), - ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16)), + ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), + ::testing::Bool()), [](const testing::TestParamInfo& info) { std::string name = std::to_string(std::get<0>(info.param).first) + "X" + std::to_string(std::get<0>(info.param).second) + "X" + test::typeName(std::get<1>(info.param)) + "X" + + (std::get<2>(info.param) ? 
"RowScaled" : "PerTensor") + "X" + "Swizzled"; return name; } diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index c756b83810..96e71f9513 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -543,6 +543,59 @@ void Tensor::set_scale(float scale) { } } +void Tensor::set_tensor_amax_shape(const std::vector &shape) { + const size_t numel = product(shape); + NVTE_CHECK(tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING, + "Amax shape override is only supported for NVFP4 test tensors."); + + auto old_amax = tensor_.get_amax(); + if (old_amax.data_ptr != nullptr) { + NVTE_CHECK_CUDA(cudaFree(old_amax.data_ptr)); + } + + float *amax = nullptr; + NVTE_CHECK_CUDA(cudaMalloc(&amax, numel * sizeof(float))); + NVTE_CHECK_CUDA(cudaMemset(amax, 0, numel * sizeof(float))); + tensor_.set_amax(amax, DType::kFloat32, shape); +} + +std::vector Tensor::tensor_amax_values() const { + const auto amax = tensor_.get_amax(); + NVTE_CHECK(static_cast(amax.dtype) == DType::kFloat32, "Tensor amax must be FP32."); + + const size_t numel = product(amax.shape); + if (numel == 0) { + return {}; + } + NVTE_CHECK(amax.data_ptr != nullptr, "Tensor amax is not allocated."); + + std::vector values(numel); + NVTE_CHECK_CUDA( + cudaMemcpy(values.data(), amax.data_ptr, numel * sizeof(float), cudaMemcpyDeviceToHost)); + return values; +} + +void Tensor::copy_tensor_amax_from(const Tensor &other) { + const auto other_amax = other.tensor_.get_amax(); + NVTE_CHECK(static_cast(other_amax.dtype) == DType::kFloat32, + "Source tensor amax must be FP32."); + + auto my_amax = tensor_.get_amax(); + NVTE_CHECK(static_cast(my_amax.dtype) == DType::kFloat32, + "Destination tensor amax must be FP32."); + NVTE_CHECK(areShapesEqual(my_amax.shape, other_amax.shape), "Amax shape mismatch."); + + const size_t numel = product(other_amax.shape); + if (numel == 0) { + return; + } + + NVTE_CHECK(other_amax.data_ptr != nullptr, "Source tensor amax is not allocated."); + NVTE_CHECK(my_amax.data_ptr != nullptr, "Destination tensor amax is not allocated."); + NVTE_CHECK_CUDA(cudaMemcpy(my_amax.data_ptr, other_amax.data_ptr, numel * sizeof(float), + cudaMemcpyDeviceToDevice)); +} + void Tensor::set_scale_inv(float scale_inv) { if (isFp8Type(dtype()) || isFp4Type(dtype())) { if (rowwise_) { diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index b8389d5833..b2a7da89cf 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -319,10 +319,18 @@ class Tensor { tensor_.set_amax(nullptr, DType::kFloat32, tensor_.defaultShape); } + void set_tensor_amax_shape(const std::vector &shape); + std::vector tensor_amax_values() const; + void copy_tensor_amax_from(const Tensor &other); + void set_with_gemm_swizzled_scales(bool with_gemm_swizzled_scales){ tensor_.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); } + void set_row_scaled_nvfp4(bool row_scaled_nvfp4) { + tensor_.set_row_scaled_nvfp4(row_scaled_nvfp4); + } + void to_cpu() const; void from_cpu() const; void set_scale(float scale); diff --git a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py index 911b7660dc..b939336275 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_gemm_exact.py @@ -8,6 +8,7 @@ import transformer_engine_torch as tex from transformer_engine.pytorch.constants import TE_DType from transformer_engine.pytorch import NVFP4Quantizer +from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm from 
transformer_engine.pytorch.custom_recipes.quantization_nvfp4 import NVFP4QuantizerRef from transformer_engine.pytorch.custom_recipes import utils @@ -26,6 +27,7 @@ def check_nvfp4_gemm_versus_reference( *, x_columnwise: bool = False, w_columnwise: bool = False, + row_scaled_nvfp4: bool = False, ): te_dtype = tex.DType.kFloat4E2M1 @@ -51,11 +53,12 @@ def check_nvfp4_gemm_versus_reference( x_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, rowwise=True, - columnwise=True, + columnwise=not row_scaled_nvfp4, with_amax_reduction=False, amax_reduction_group=None, with_rht=False, with_post_rht_amax=False, + row_scaled_nvfp4=row_scaled_nvfp4, ) w_quantizer = NVFP4Quantizer( fp4_dtype=te_dtype, @@ -112,7 +115,16 @@ def check_nvfp4_gemm_versus_reference( sw_trimmed = sw_trimmed.view(torch.float8_e4m3fn) # Create reference quantizer for reference GEMM - ref_quantizer = NVFP4QuantizerRef( + x_ref_quantizer = NVFP4QuantizerRef( + dtype=utils.Fp4Formats.E2M1, + rowwise=True, + columnwise=not row_scaled_nvfp4, + pow_2_scales=False, + eps=0.0, + quant_tile_shape=(1, 16), + row_scaled_nvfp4=row_scaled_nvfp4, + ) + w_ref_quantizer = NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, rowwise=True, columnwise=True, @@ -124,16 +136,16 @@ def check_nvfp4_gemm_versus_reference( # Create reference quantized tensors needed by reference GEMM # Reference GEMM is only rowwise. if x_columnwise: - x_nvfp4_ref = ref_quantizer.quantize(x.t().contiguous()) + x_nvfp4_ref = x_ref_quantizer.quantize(x.t().contiguous()) else: - x_nvfp4_ref = ref_quantizer.quantize(x) + x_nvfp4_ref = x_ref_quantizer.quantize(x) if w_columnwise: - w_nvfp4_ref = ref_quantizer.quantize(w.t().contiguous()) + w_nvfp4_ref = w_ref_quantizer.quantize(w.t().contiguous()) else: - w_nvfp4_ref = ref_quantizer.quantize(w) + w_nvfp4_ref = w_ref_quantizer.quantize(w) # Reference GEMM using quantizer's qgemm method - y_ref = ref_quantizer.qgemm( + y_ref = x_ref_quantizer.qgemm( qx=qx_data, qw=qw_data, m_params=None, # MMParams not used in reference @@ -166,27 +178,38 @@ def check_nvfp4_gemm_versus_reference( x_nvfp4_native.update_usage(rowwise_usage=False) if w_columnwise: w_nvfp4_native.update_usage(rowwise_usage=False) - # Native cuBLAS GEMM - # return type is out, bias_grad, gelu_input, extra_output - # We are just capturing out. - y_native = tex.generic_gemm( - w_nvfp4_native, - transa, - x_nvfp4_native, - transb, - out.clone() if accumulate else None, - out_quantizer, - TE_DType[out_dtype], - bias, - bias_dtype, - use_gelu, - gelu_input, - use_grad, - workspace, - workspace.shape[0], - accumulate, - use_split_accumulator, - )[0] + if row_scaled_nvfp4: + layout = ("T" if transa else "N") + ("T" if transb else "N") + y_native = general_gemm( + w_nvfp4_native, + x_nvfp4_native, + out_dtype=out_dtype, + accumulate=accumulate, + layout=layout, + out=out.clone() if accumulate else None, + )[0] + else: + # Native cuBLAS GEMM + # return type is out, bias_grad, gelu_input, extra_output + # We are just capturing out. 
+ y_native = tex.generic_gemm( + w_nvfp4_native, + transa, + x_nvfp4_native, + transb, + out.clone() if accumulate else None, + out_quantizer, + TE_DType[out_dtype], + bias, + bias_dtype, + use_gelu, + gelu_input, + use_grad, + workspace, + workspace.shape[0], + accumulate, + use_split_accumulator, + )[0] # just in case of accumulation, make sure y_ref and y_native are not the same tensor assert y_ref is not y_native, "y_ref and y_native should not be the same tensor" @@ -199,6 +222,170 @@ def check_nvfp4_gemm_versus_reference( torch.testing.assert_close(y_native, y_ref, atol=8e-3, rtol=8e-3) +def check_nvfp4_row_scaled_grouped_gemm_matches_per_gemm( + x_dtype: torch.dtype, + w_dtype: torch.dtype, + out_dtype: torch.dtype, + m_splits: list[int], + k: int, + n: int, + *, + use_bias: bool, + single_output: bool, +): + te_dtype = tex.DType.kFloat4E2M1 + device = "cuda" + torch.manual_seed(23) + torch.cuda.manual_seed(23) + + num_gemms = len(m_splits) + + x_quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=False, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + row_scaled_nvfp4=True, + ) + w_quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=True, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + ) + + x_nvfp4 = [] + w_nvfp4 = [] + bias = [] + expected = [] + for m in m_splits: + x = torch.randn((m, k), dtype=x_dtype, device=device) + w = torch.randn((n, k), dtype=w_dtype, device=device) + x_nvfp4.append( + x_quantizer.update_quantized( + x, x_quantizer.make_empty(x.shape, dtype=x_dtype, device=device) + ) + ) + w_nvfp4.append( + w_quantizer.update_quantized( + w, w_quantizer.make_empty(w.shape, dtype=w_dtype, device=device) + ) + ) + bias.append(torch.randn(n, dtype=torch.bfloat16, device=device) if use_bias else None) + expected.append( + general_gemm( + w_nvfp4[-1], + x_nvfp4[-1], + out_dtype=out_dtype, + layout="TN", + bias=bias[-1], + )[0] + ) + + if single_output: + out = [torch.empty((sum(m_splits), n), dtype=out_dtype, device=device)] + else: + out = [torch.empty((m, n), dtype=out_dtype, device=device) for m in m_splits] + + grouped_out, _, _ = general_grouped_gemm( + w_nvfp4, + x_nvfp4, + out, + quantization_params=[None] * num_gemms, + out_dtype=out_dtype, + layout="TN", + m_splits=m_splits, + bias=bias, + use_bias=use_bias, + single_output=single_output, + ) + + if single_output: + grouped_slices = torch.split(grouped_out, m_splits, dim=0) + else: + grouped_slices = grouped_out + for grouped, ref in zip(grouped_slices, expected): + torch.testing.assert_close(grouped, ref, atol=0.0, rtol=0.0) + + +def check_nvfp4_row_scaled_gemm_matches_emulated( + x_dtype: torch.dtype, + w_dtype: torch.dtype, + out_dtype: torch.dtype, + M: int, + K: int, + N: int, +): + te_dtype = tex.DType.kFloat4E2M1 + device = "cuda" + torch.manual_seed(37) + torch.cuda.manual_seed(37) + + x = torch.randn((M, K), dtype=x_dtype, device=device) + w = torch.randn((N, K), dtype=w_dtype, device=device) + + x_row_scaled_quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=False, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + row_scaled_nvfp4=True, + ) + x_tensorwise_quantizer = NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=True, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + ) + w_quantizer 
= NVFP4Quantizer( + fp4_dtype=te_dtype, + rowwise=True, + columnwise=True, + with_amax_reduction=False, + amax_reduction_group=None, + with_rht=False, + with_post_rht_amax=False, + ) + + x_row_scaled = x_row_scaled_quantizer.update_quantized( + x, x_row_scaled_quantizer.make_empty(x.shape, dtype=x_dtype, device=device) + ) + w_nvfp4 = w_quantizer.update_quantized( + w, w_quantizer.make_empty(w.shape, dtype=w_dtype, device=device) + ) + y_row_scaled = general_gemm(w_nvfp4, x_row_scaled, out_dtype=out_dtype, layout="TN")[0] + + emulated_rows = [] + for i in range(M): + x_padded = torch.zeros((16, K), dtype=x_dtype, device=device) + x_padded[0].copy_(x[i]) + x_tensorwise = x_tensorwise_quantizer.update_quantized( + x_padded, + x_tensorwise_quantizer.make_empty(x_padded.shape, dtype=x_dtype, device=device), + ) + emulated_rows.append( + general_gemm(w_nvfp4, x_tensorwise, out_dtype=out_dtype, layout="TN")[0][:1] + ) + + y_emulated = torch.cat(emulated_rows, dim=0) + if out_dtype == torch.bfloat16: + torch.testing.assert_close(y_row_scaled, y_emulated, atol=0.0, rtol=7.8e-3) + else: + torch.testing.assert_close(y_row_scaled, y_emulated, atol=3.0517578125e-5, rtol=0.0) + + @pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) @pytest.mark.parametrize( "M, K, N", @@ -229,6 +416,7 @@ def check_nvfp4_gemm_versus_reference( ], ids=["rowxrow", "colxrow", "colxcol"], ) +@pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) def test_nvfp4_gemm_versus_reference( M: int, K: int, @@ -239,7 +427,14 @@ def test_nvfp4_gemm_versus_reference( accumulate: bool, is_x_columnwise: bool, is_w_columnwise: bool, + row_scaled_nvfp4: bool, ): + if row_scaled_nvfp4: + if accumulate: + pytest.skip("Row-scaled NVFP4 GEMM output rescale does not support accumulation") + if is_x_columnwise: + pytest.skip("Row-scaled NVFP4 GEMM output rescale requires rowwise RHS usage") + check_nvfp4_gemm_versus_reference( x_dtype=x_dtype, w_dtype=w_dtype, @@ -250,4 +445,87 @@ def test_nvfp4_gemm_versus_reference( accumulate=accumulate, x_columnwise=is_x_columnwise, w_columnwise=is_w_columnwise, + row_scaled_nvfp4=row_scaled_nvfp4, + ) + + +@pytest.mark.skipif(not recipe_available, reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "m_splits, k, n", + [ + ([32, 48, 48], 128, 128), + ([64, 80, 112], 128, 256), + ([64, 80, 112], 256, 256), + ([64, 80, 112], 1024, 256), + ([256, 256, 512], 1024, 1024), + ([1024, 1536, 1536], 512, 3072), + ([16, 32, 64], 128, 96), + ([80, 96, 128], 640, 304), + ([320, 336, 352], 3072, 992), + ([64, 80, 112], 64, 256), + ([32, 48, 48], 128, 112), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("use_bias", [False, True], ids=["no_bias", "bias"]) +@pytest.mark.parametrize("single_output", [False, True], ids=["list_output", "single_output"]) +def test_nvfp4_row_scaled_grouped_gemm_matches_per_gemm( + m_splits: list[int], + k: int, + n: int, + x_dtype: torch.dtype, + w_dtype: torch.dtype, + out_dtype: torch.dtype, + use_bias: bool, + single_output: bool, +): + check_nvfp4_row_scaled_grouped_gemm_matches_per_gemm( + x_dtype=x_dtype, + w_dtype=w_dtype, + out_dtype=out_dtype, + m_splits=m_splits, + k=k, + n=n, + use_bias=use_bias, + single_output=single_output, + ) + + +@pytest.mark.skipif(not recipe_available, 
reason=reason_for_no_recipe) +@pytest.mark.parametrize( + "M, K, N", + [ + (128, 128, 128), + (256, 128, 256), + (256, 256, 256), + (256, 1024, 256), + (1024, 1024, 1024), + (4096, 512, 3072), + (112, 128, 96), + (304, 640, 304), + (1008, 3072, 992), + (256, 64, 256), + (128, 128, 112), + ], +) +@pytest.mark.parametrize("x_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("w_dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("out_dtype", [torch.bfloat16, torch.float32], ids=str) +def test_nvfp4_row_scaled_gemm_matches_emulated( + M: int, + K: int, + N: int, + x_dtype: torch.dtype, + w_dtype: torch.dtype, + out_dtype: torch.dtype, +): + check_nvfp4_row_scaled_gemm_matches_emulated( + x_dtype=x_dtype, + w_dtype=w_dtype, + out_dtype=out_dtype, + M=M, + K=K, + N=N, ) diff --git a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py index bf3f545b8b..0824a5e7bc 100644 --- a/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py +++ b/tests/pytorch/nvfp4/test_nvfp4_quantize_exact.py @@ -16,6 +16,19 @@ recipe_available, reason_for_no_recipe = te.is_nvfp4_available(return_reason=True) +def maybe_skip_row_scaled_unsupported_quantization( + row_scaled_nvfp4: bool, + return_transpose: bool, + with_2d_quantization: bool = False, +) -> None: + if not row_scaled_nvfp4: + return + if return_transpose: + pytest.skip("Row-scaled NVFP4 does not support columnwise usage") + if with_2d_quantization: + pytest.skip("Row-scaled NVFP4 does not support 2D quantization") + + def unpack_fp4(x: torch.Tensor) -> torch.Tensor: repeated = x.repeat_interleave(2, dim=1) repeated[:, 0::2] &= 0x0F @@ -31,7 +44,12 @@ def check_quantization_nvfp4_versus_reference( swizzled_scale: bool, use_cpp_allocator: bool, with_2d_quantization: bool, + row_scaled_nvfp4: bool = False, ) -> None: + maybe_skip_row_scaled_unsupported_quantization( + row_scaled_nvfp4, return_transpose, with_2d_quantization + ) + te_dtype = tex.DType.kFloat4E2M1 # Setup device and random seed @@ -52,6 +70,7 @@ def check_quantization_nvfp4_versus_reference( with_rht=False, with_post_rht_amax=False, with_2d_quantization=with_2d_quantization, + row_scaled_nvfp4=row_scaled_nvfp4, ) if use_cpp_allocator: x_nvfp4_sut = nvfp4_quantizer(x) @@ -73,6 +92,7 @@ def check_quantization_nvfp4_versus_reference( ) sx_t = x_nvfp4_sut._columnwise_scale_inv qx_amax = x_nvfp4_sut._amax_rowwise + qx_amax_t = x_nvfp4_sut._amax_columnwise # Reference quantization quant_tile_shape = (1, 16) if not with_2d_quantization else (16, 16) @@ -83,6 +103,7 @@ def check_quantization_nvfp4_versus_reference( pow_2_scales=False, eps=0.0, quant_tile_shape=quant_tile_shape, + row_scaled_nvfp4=row_scaled_nvfp4, ) x_nvfp4_ref = ref_quantizer.quantize(x) @@ -102,6 +123,7 @@ def check_quantization_nvfp4_versus_reference( x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None ) ref_amax = x_nvfp4_ref.global_amax_row + ref_amax_t = x_nvfp4_ref.global_amax_col qx = unpack_fp4(qx) qx_t = unpack_fp4(qx_t) if qx_t is not None else None @@ -121,6 +143,7 @@ def check_quantization_nvfp4_versus_reference( ref_sx_t_shape = sx_t_ref.shape sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]] torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0) + torch.testing.assert_close(qx_amax_t, ref_amax_t, atol=0.0, rtol=0.0) torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0) @@ -155,6 +178,7 @@ def check_quantization_nvfp4_versus_reference( @pytest.mark.parametrize( 
"with_2d_quantization", [True, False], ids=["2d_quantization", "1d_quantization"] ) +@pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) def test_quantization_block_tiling_versus_reference( x_dtype: torch.dtype, M: int, @@ -163,6 +187,7 @@ def test_quantization_block_tiling_versus_reference( swizzled_scale: bool, use_cpp_allocator: bool, with_2d_quantization: bool, + row_scaled_nvfp4: bool, ) -> None: check_quantization_nvfp4_versus_reference( x_dtype=x_dtype, @@ -172,6 +197,7 @@ def test_quantization_block_tiling_versus_reference( swizzled_scale=swizzled_scale, use_cpp_allocator=use_cpp_allocator, with_2d_quantization=with_2d_quantization, + row_scaled_nvfp4=row_scaled_nvfp4, ) @@ -188,6 +214,7 @@ def test_quantization_block_tiling_versus_reference( @pytest.mark.parametrize( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) +@pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) def test_nvfp4_quantization_extrema_versus_reference( x_dtype: torch.dtype, M: int, @@ -195,7 +222,10 @@ def test_nvfp4_quantization_extrema_versus_reference( extrema_high: bool, return_transpose: bool, use_cpp_allocator: bool, + row_scaled_nvfp4: bool, ): + maybe_skip_row_scaled_unsupported_quantization(row_scaled_nvfp4, return_transpose) + te_dtype = tex.DType.kFloat4E2M1 device = "cuda" @@ -216,6 +246,7 @@ def test_nvfp4_quantization_extrema_versus_reference( amax_reduction_group=None, with_rht=False, with_post_rht_amax=False, + row_scaled_nvfp4=row_scaled_nvfp4, ) if use_cpp_allocator: @@ -237,6 +268,7 @@ def test_nvfp4_quantization_extrema_versus_reference( ) sx_t = x_nvfp4_sut._columnwise_scale_inv qx_amax = x_nvfp4_sut._amax_rowwise + qx_amax_t = x_nvfp4_sut._amax_columnwise ref_quantizer = NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, @@ -245,6 +277,7 @@ def test_nvfp4_quantization_extrema_versus_reference( pow_2_scales=False, eps=0.0, quant_tile_shape=(1, 16), + row_scaled_nvfp4=row_scaled_nvfp4, ) x_nvfp4_ref = ref_quantizer.quantize(x) @@ -257,6 +290,7 @@ def test_nvfp4_quantization_extrema_versus_reference( x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None ) ref_amax = x_nvfp4_ref.global_amax_row + ref_amax_t = x_nvfp4_ref.global_amax_col torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0) @@ -269,6 +303,7 @@ def test_nvfp4_quantization_extrema_versus_reference( ref_sx_t_shape = sx_t_ref.shape sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]] torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0) + torch.testing.assert_close(qx_amax_t, ref_amax_t, atol=0.0, rtol=0.0) torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0) @@ -286,18 +321,22 @@ def test_nvfp4_quantization_extrema_versus_reference( @pytest.mark.parametrize( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) +@pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) def test_nvfp4_quantization_boundary_values( x_dtype: torch.dtype, M: int, N: int, return_transpose: bool, use_cpp_allocator: bool, + row_scaled_nvfp4: bool, ): """ Stress rounding/threshold behavior by placing values just below/above many potential bin edges within each 16-element microblock. Validates native vs reference byte-for-byte and scale parity. 
""" + maybe_skip_row_scaled_unsupported_quantization(row_scaled_nvfp4, return_transpose) + te_dtype = tex.DType.kFloat4E2M1 device = "cuda" @@ -327,6 +366,7 @@ def test_nvfp4_quantization_boundary_values( amax_reduction_group=None, with_rht=False, with_post_rht_amax=False, + row_scaled_nvfp4=row_scaled_nvfp4, ) if use_cpp_allocator: @@ -348,6 +388,7 @@ def test_nvfp4_quantization_boundary_values( ) sx_t = x_nvfp4_sut._columnwise_scale_inv qx_amax = x_nvfp4_sut._amax_rowwise + qx_amax_t = x_nvfp4_sut._amax_columnwise ref_quantizer = NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, @@ -356,6 +397,7 @@ def test_nvfp4_quantization_boundary_values( pow_2_scales=False, eps=0.0, quant_tile_shape=(1, 16), + row_scaled_nvfp4=row_scaled_nvfp4, ) x_nvfp4_ref = ref_quantizer.quantize(x) @@ -368,6 +410,7 @@ def test_nvfp4_quantization_boundary_values( x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None ) ref_amax = x_nvfp4_ref.global_amax_row + ref_amax_t = x_nvfp4_ref.global_amax_col torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0) @@ -381,6 +424,7 @@ def test_nvfp4_quantization_boundary_values( ref_sx_t_shape = sx_t_ref.shape sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]] torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0) + torch.testing.assert_close(qx_amax_t, ref_amax_t, atol=0.0, rtol=0.0) torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0) @@ -397,13 +441,17 @@ def test_nvfp4_quantization_boundary_values( @pytest.mark.parametrize( "use_cpp_allocator", [True, False], ids=["cpp_allocator", "python_allocator"] ) +@pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) def test_nvfp4_quantization_noncontiguous_inputs( x_dtype: torch.dtype, M: int, N: int, return_transpose: bool, use_cpp_allocator: bool, + row_scaled_nvfp4: bool, ): + maybe_skip_row_scaled_unsupported_quantization(row_scaled_nvfp4, return_transpose) + te_dtype = tex.DType.kFloat4E2M1 device = "cuda" @@ -424,6 +472,7 @@ def test_nvfp4_quantization_noncontiguous_inputs( amax_reduction_group=None, with_rht=False, with_post_rht_amax=False, + row_scaled_nvfp4=row_scaled_nvfp4, ) if use_cpp_allocator: @@ -445,6 +494,7 @@ def test_nvfp4_quantization_noncontiguous_inputs( ) sx_t = x_nvfp4_sut._columnwise_scale_inv qx_amax = x_nvfp4_sut._amax_rowwise + qx_amax_t = x_nvfp4_sut._amax_columnwise ref_quantizer = NVFP4QuantizerRef( dtype=utils.Fp4Formats.E2M1, @@ -453,6 +503,7 @@ def test_nvfp4_quantization_noncontiguous_inputs( pow_2_scales=False, eps=0.0, quant_tile_shape=(1, 16), + row_scaled_nvfp4=row_scaled_nvfp4, ) x_nvfp4_ref = ref_quantizer.quantize(x_nc) @@ -465,6 +516,7 @@ def test_nvfp4_quantization_noncontiguous_inputs( x_nvfp4_ref.scale_t.view(dtype=torch.uint8) if x_nvfp4_ref.scale_t is not None else None ) ref_amax = x_nvfp4_ref.global_amax_row + ref_amax_t = x_nvfp4_ref.global_amax_col # Quantized must match torch.testing.assert_close(qx, qx_ref, atol=0.0, rtol=0.0) @@ -479,5 +531,6 @@ def test_nvfp4_quantization_noncontiguous_inputs( ref_sx_t_shape = sx_t_ref.shape sx_t_valid = sx_t[: ref_sx_t_shape[0], : ref_sx_t_shape[1]] torch.testing.assert_close(sx_t_valid, sx_t_ref, atol=0.0, rtol=0.0) + torch.testing.assert_close(qx_amax_t, ref_amax_t, atol=0.0, rtol=0.0) torch.testing.assert_close(qx_amax, ref_amax, atol=0.0, rtol=0.0) diff --git a/tests/pytorch/test_backward_override.py b/tests/pytorch/test_backward_override.py index ed4f73adbc..c7c5a5b99d 100644 --- a/tests/pytorch/test_backward_override.py +++ 
b/tests/pytorch/test_backward_override.py @@ -78,6 +78,11 @@ marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4), id="NVFP4BlockScaling", ), + pytest.param( + "nvfp4_row_scaled", + marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4), + id="NVFP4RowScaledBlockScaling", + ), ] @@ -165,7 +170,7 @@ def _maybe_skip_recipe_dtype( ) -> None: if dtype == torch.bfloat16 and not bf16_available: pytest.skip(reason_for_no_bf16) - if recipe_name == "nvfp4": + if recipe_name in ("nvfp4", "nvfp4_row_scaled"): if module_type in ("linear", "layernorm_linear") and dtype not in ( torch.bfloat16, torch.float32, @@ -178,6 +183,14 @@ def _maybe_skip_recipe_dtype( def _maybe_skip_unsupported_recipe_module_combo(recipe_name: str, module_type: str) -> None: if module_type == "ops_linear" and recipe_name == "fp8_block_scaling": pytest.skip("Fusible ops (te_ops.Linear) do not support Float8BlockScaling recipe") + if module_type == "ops_linear" and recipe_name == "nvfp4_row_scaled": + pytest.skip("Row-scaled NVFP4 currently does not support fused te_ops paths.") + + +def _make_quantized_forward_reference_recipe(recipe_name: str) -> recipe.Recipe: + if recipe_name == "nvfp4_row_scaled": + return make_recipe(recipe_name, backward_override="dequantized") + return make_recipe(recipe_name) def _maybe_skip_unsupported_recipe_shape( @@ -195,7 +208,9 @@ def _maybe_skip_unsupported_recipe_shape( " by 32." ) return - if recipe_name == "nvfp4" and (flat_first_dim % 16 != 0 or last_dim % 16 != 0): + if recipe_name in ("nvfp4", "nvfp4_row_scaled") and ( + flat_first_dim % 16 != 0 or last_dim % 16 != 0 + ): pytest.skip( "Linear/LayerNormLinear + NVFP4 requires prod(shape[:-1]) and shape[-1] divisible" " by 16." @@ -220,7 +235,9 @@ def _maybe_skip_unsupported_recipe_shape( pytest.skip( "te_ops.Linear + MXFP8 requires prod(shape[:-1]) and shape[-1] divisible by 32." ) - if recipe_name == "nvfp4" and (flat_first_dim % 16 != 0 or last_dim % 16 != 0): + if recipe_name in ("nvfp4", "nvfp4_row_scaled") and ( + flat_first_dim % 16 != 0 or last_dim % 16 != 0 + ): pytest.skip( "te_ops.Linear + NVFP4 requires prod(shape[:-1]) and shape[-1] divisible by 16." ) @@ -239,9 +256,9 @@ def _maybe_skip_unsupported_grouped_splits(recipe_name: str, m_splits: list[int] ) if recipe_name == "mxfp8" and any(m % 32 != 0 for m in non_empty_splits): pytest.skip("GroupedLinear + MXFP8 requires each non-empty m_split divisible by 32.") - if recipe_name == "nvfp4" and any(m % 16 != 0 for m in non_empty_splits): + if recipe_name in ("nvfp4", "nvfp4_row_scaled") and any(m % 16 != 0 for m in non_empty_splits): pytest.skip("GroupedLinear + NVFP4 requires each non-empty m_split divisible by 16.") - if recipe_name == "nvfp4" and any(m % 64 != 0 for m in non_empty_splits): + if recipe_name in ("nvfp4", "nvfp4_row_scaled") and any(m % 64 != 0 for m in non_empty_splits): pytest.skip( "GroupedLinear + NVFP4 grouped split_quantize currently requires each non-empty " "m_split divisible by 64 due to grouped amax kernel constraints." 
@@ -847,7 +864,7 @@ def test_linear_like_backward_override_matches_reference( _maybe_skip_unsupported_recipe_shape(recipe_name, input_shape, module_type) in_features = input_shape[-1] - quantized_ref_recipe = make_recipe(recipe_name) + quantized_ref_recipe = _make_quantized_forward_reference_recipe(recipe_name) mode_recipe = make_recipe(recipe_name, backward_override=backward_override) skip_unsupported_backward_override(module_type, mode_recipe, backward_override) @@ -1031,8 +1048,9 @@ def test_grouped_linear_backward_override_matches_reference( num_gemms = len(m_splits) num_tokens = sum(m_splits) - quantized_ref_recipe = make_recipe(recipe_name) + quantized_ref_recipe = _make_quantized_forward_reference_recipe(recipe_name) mode_recipe = make_recipe(recipe_name, backward_override=backward_override) + skip_unsupported_backward_override("grouped_linear", mode_recipe, backward_override) module_quantized_ref = te.GroupedLinear( num_gemms, @@ -1200,6 +1218,7 @@ def test_linear_like_runtime_backward_override_switch_updates_ctx( dy = torch.randn(*input_shape[:-1], out_features, dtype=dtype, device="cuda") default_recipe = make_recipe(recipe_name) + skip_unsupported_backward_override(module_type, default_recipe, None) mode_recipe = make_recipe(recipe_name, backward_override=backward_override) skip_unsupported_backward_override(module_type, mode_recipe, backward_override) @@ -1270,7 +1289,9 @@ def test_grouped_linear_runtime_backward_override_switch_updates_ctx( dy = torch.randn(num_tokens, out_features, dtype=dtype, device="cuda") default_recipe = make_recipe(recipe_name) + skip_unsupported_backward_override("grouped_linear", default_recipe, None) mode_recipe = make_recipe(recipe_name, backward_override=backward_override) + skip_unsupported_backward_override("grouped_linear", mode_recipe, backward_override) *_, default_ctx = _run_grouped_linear_single_step_with_ctx_state( module, @@ -1336,7 +1357,7 @@ def test_fused_linear_paths_match_backward_override_reference( reset_rng_states() - quantized_ref_recipe = make_recipe(recipe_name) + quantized_ref_recipe = _make_quantized_forward_reference_recipe(recipe_name) mode_recipe = make_recipe(recipe_name, backward_override=backward_override) skip_unsupported_backward_override("ops_linear", mode_recipe, backward_override) @@ -1476,7 +1497,7 @@ def test_fused_bias_activation_matches_masked_linear_backward( reset_rng_states() in_features = input_shape[-1] - quantized_ref_recipe = make_recipe(recipe_name) + quantized_ref_recipe = _make_quantized_forward_reference_recipe(recipe_name) mode_recipe = make_recipe(recipe_name, backward_override=backward_override) skip_unsupported_backward_override("ops_linear", mode_recipe, backward_override) @@ -1715,7 +1736,11 @@ def test_backward_override_memory_peak_report( x = torch.randn(*input_shape, dtype=dtype, device="cuda") dy = torch.randn(*input_shape[:-1], out_features, dtype=dtype, device="cuda") - modes = (None, "high_precision", "dequantized") + modes = ( + ("high_precision", "dequantized") + if recipe_name == "nvfp4_row_scaled" + else (None, "high_precision", "dequantized") + ) mode_results: dict[str, dict[str, float] | str] = {} for mode in modes: diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py index a782dadc60..33ba65e0d9 100644 --- a/tests/pytorch/test_cuda_graphs.py +++ b/tests/pytorch/test_cuda_graphs.py @@ -20,17 +20,19 @@ is_fp8_available, is_fp8_block_scaling_available, is_mxfp8_available, + is_nvfp4_available, is_bf16_available, ) from 
transformer_engine.pytorch.quantization import FP8GlobalStateManager import transformer_engine.pytorch.ops as te_ops from transformer_engine.common import recipe -from utils import ModelConfig, reset_rng_states, skip_unsupported_backward_override +from utils import ModelConfig, recipe_id, reset_rng_states, skip_unsupported_backward_override # Check if FP8 is supported. fp8_available = is_fp8_available() fp8_block_scaling_available = is_fp8_block_scaling_available() mxfp8_available = is_mxfp8_available() +nvfp4_available = is_nvfp4_available() # Reset RNG states. reset_rng_states() @@ -62,6 +64,14 @@ def nvfp4_rht_and_2d_quantization(): return nvfp4_recipe +def nvfp4_row_scaled(): + nvfp4_recipe = recipe.NVFP4BlockScaling(row_scaled_activation=True) + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() + return nvfp4_recipe + + def check_rht_usage(recipe: recipe.Recipe) -> bool: # if using RHT, we can only support bf16 # check fp4_quant_fwd_inp, fp4_quant_fwd_weight, fp4_quant_bwd_grad @@ -88,7 +98,9 @@ def get_nvfp4_inp_supported_dtypes(recipe: recipe.Recipe, dtype: torch.dtype) -> fp8_recipes = [] if mxfp8_available: fp8_recipes.append(recipe.MXFP8BlockScaling()) +if nvfp4_available: fp8_recipes.append(nvfp4_rht_and_2d_quantization()) + fp8_recipes.append(nvfp4_row_scaled()) if fp8_block_scaling_available: fp8_recipes.append(recipe.Float8BlockScaling()) if fp8_available: @@ -360,7 +372,7 @@ def _test_cuda_graphs( @pytest.mark.parametrize("module", _test_cuda_graphs_modules) @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("fp8_params", (False, True)) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes + [None], ids=lambda r: type(r).__name__) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes + [None], ids=recipe_id) @pytest.mark.parametrize("backward_override", (None, "high_precision", "dequantized")) def test_make_graphed_callables( *, @@ -390,6 +402,8 @@ def test_make_graphed_callables( f"Module not yet supported for {fp8_recipe.__class__.__name__} with CUDA graphs" ) if fp8 and fp8_recipe.nvfp4(): + if getattr(fp8_recipe, "row_scaled_activation", False) and module == "mha": + pytest.skip("Row-scaled NVFP4 CUDA graph coverage applies to GEMM modules.") if dtype not in get_nvfp4_inp_supported_dtypes(fp8_recipe, dtype): pytest.skip( f"Input dtype {dtype} not supported for NVFP4 Recipe" @@ -448,7 +462,7 @@ def test_make_graphed_callables( ) @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("fp8_params", (False, True)) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=lambda r: type(r).__name__) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes, ids=recipe_id) @pytest.mark.parametrize("backward_override", (None, "high_precision", "dequantized")) def test_make_graphed_callables_with_fp8_weight_caching( *, diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index 91d4b89013..5f5221af76 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -25,10 +25,16 @@ import transformer_engine_torch as tex from transformer_engine.pytorch.quantization import ( FP8GlobalStateManager, + NVFP4BlockScalingRecipeState, _amax_and_scale_update, ) import transformer_engine.pytorch.ops as te_ops -from transformer_engine.common.recipe import DelayedScaling, Float8BlockScaling, MXFP8BlockScaling +from transformer_engine.common.recipe import ( + DelayedScaling, + Float8BlockScaling, + MXFP8BlockScaling, + 
NVFP4BlockScaling, +) # Check if FP8 is supported fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) @@ -507,8 +513,30 @@ def test_quantizer_update(self, module_class): y = module(x) +@pytest.mark.skipif(not fp4_available, reason=reason_for_no_fp4) +def test_nvfp4_row_scaled_quantizer_roles(): + recipe = NVFP4BlockScaling(row_scaled_activation=True) + + forward_quantizers = NVFP4BlockScalingRecipeState( + recipe, + mode="forward", + num_quantizers=3, + ).make_quantizers() + assert [q.row_scaled_nvfp4 for q in forward_quantizers] == [True, False, True] + assert not forward_quantizers[0].is_quantizable(torch.empty(16, 16)) + assert forward_quantizers[1].is_quantizable(torch.empty(16, 16)) + + backward_quantizers = NVFP4BlockScalingRecipeState( + recipe, + mode="backward", + num_quantizers=2, + ).make_quantizers() + assert [q.row_scaled_nvfp4 for q in backward_quantizers] == [False, False] + + @pytest.mark.skipif(not fp4_available, reason=reason_for_no_fp4) @pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16], ids=str) +@pytest.mark.parametrize("row_scaled_nvfp4", [False, True], ids=["nvfp4", "nvfp4_row_scaled"]) @pytest.mark.parametrize( "M, N", [ @@ -524,12 +552,19 @@ def test_quantizer_update(self, module_class): (8192, 8192), ], ) -def test_fp4_dequantize(dtype, M, N): - q = NVFP4Quantizer() +def test_fp4_dequantize(dtype, row_scaled_nvfp4, M, N): + q = NVFP4Quantizer( + columnwise=not row_scaled_nvfp4, + row_scaled_nvfp4=row_scaled_nvfp4, + ) a = torch.rand((M, N)).cuda().to(dtype=dtype) starting_tensor = q(a) + assert starting_tensor._row_scaled_nvfp4 == row_scaled_nvfp4 + assert starting_tensor._amax_rowwise.numel() == (M if row_scaled_nvfp4 else 1) dequantized_tensor = starting_tensor.dequantize() new_tensor = q(dequantized_tensor) + assert new_tensor._row_scaled_nvfp4 == row_scaled_nvfp4 + assert new_tensor._amax_rowwise.numel() == (M if row_scaled_nvfp4 else 1) torch.testing.assert_close( new_tensor._rowwise_data, starting_tensor._rowwise_data, diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 7f2f24fd69..c811342df5 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -38,12 +38,13 @@ import transformer_engine_torch as tex from transformer_engine.pytorch.cpp_extensions import general_gemm from transformer_engine.pytorch.tensor.utils import replace_raw_data -from utils import ModelConfig, skip_unsupported_backward_override +from utils import ModelConfig, recipe_id, skip_unsupported_backward_override # Only run FP8 tests on supported devices. fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) fp8_block_scaling_available, _ = te.is_fp8_block_scaling_available(return_reason=True) mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) +nvfp4_available, _ = te.is_nvfp4_available(return_reason=True) # Record initial RNG state from script run. 
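Condensing `test_fp4_dequantize` above: a row-scaled quantizer allocates one rowwise amax per row and the quantize/dequantize/quantize round trip reproduces the same packed FP4 payload (constructor arguments and private fields as used in the test):

```python
import torch
from transformer_engine.pytorch import NVFP4Quantizer

q = NVFP4Quantizer(columnwise=False, row_scaled_nvfp4=True)
a = torch.rand(256, 1024, device="cuda", dtype=torch.bfloat16)

t = q(a)
assert t._row_scaled_nvfp4
assert t._amax_rowwise.numel() == 256   # one FP32 amax per row, not a scalar

t2 = q(t.dequantize())                  # round trip through high precision
torch.testing.assert_close(t2._rowwise_data, t._rowwise_data)
```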
seed = 1234 @@ -93,9 +94,18 @@ def nvfp4_vanilla(): return nvfp4_recipe +def nvfp4_row_scaled(): + nvfp4_recipe = recipe.NVFP4BlockScaling(row_scaled_activation=True) + nvfp4_recipe.fp4_quant_fwd_inp = recipe.QParams() + nvfp4_recipe.fp4_quant_fwd_weight = recipe.QParams() + nvfp4_recipe.fp4_quant_bwd_grad = recipe.QParams() + return nvfp4_recipe + + fp8_recipes = [] if mxfp8_available: fp8_recipes.append(recipe.MXFP8BlockScaling()) +if nvfp4_available: fp8_recipes.append(nvfp4_vanilla()) # TODO: fix check for this if fp8_block_scaling_available: fp8_recipes.append(recipe.Float8BlockScaling()) @@ -103,6 +113,9 @@ def nvfp4_vanilla(): fp8_recipes.append(recipe.Float8CurrentScaling()) fp8_recipes.append(recipe.DelayedScaling()) fp8_recipes.append(None) +fp8_recipes_with_row_scaled = fp8_recipes.copy() +if nvfp4_available: + fp8_recipes_with_row_scaled.insert(-1, nvfp4_row_scaled()) param_types = [torch.float32, torch.float16] if is_bf16_available(): # bf16 requires sm_80 or higher @@ -402,7 +415,7 @@ def test_sanity_normalization_amp(dtype, model, skip_wgrad, skip_dgrad, normaliz @pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes_with_row_scaled, ids=recipe_id) @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) @pytest.mark.parametrize("model", ["small", "weird"]) @pytest.mark.parametrize("skip_wgrad", all_boolean) @@ -450,7 +463,7 @@ def test_sanity_layernorm_linear( @pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes_with_row_scaled, ids=recipe_id) @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) @pytest.mark.parametrize("model", ["small", "weird"]) @pytest.mark.parametrize("skip_wgrad", all_boolean) @@ -488,7 +501,7 @@ def test_sanity_linear( @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes_with_zero) @pytest.mark.parametrize("model", ["small", "weird"]) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes_with_row_scaled, ids=recipe_id) @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) @pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("use_bias", all_boolean) @@ -529,7 +542,7 @@ def test_sanity_linear_with_zero_tokens( @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("bs", batch_sizes_with_zero) @pytest.mark.parametrize("model", ["small", "weird"]) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes_with_row_scaled, ids=recipe_id) @pytest.mark.parametrize("backward_override", [None, "high_precision", "dequantized"]) @pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("use_bias", all_boolean) @@ -563,7 +576,12 @@ def test_sanity_grouped_linear( if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") if fp8_recipe.nvfp4(): - pytest.skip("NVFP4 not supported for grouped linear") + if not getattr(fp8_recipe, "row_scaled_activation", False): + pytest.skip("NVFP4 not supported for grouped linear") + if single_param: + pytest.skip("Row-scaled NVFP4 does not support GroupedTensor grouped linear") + if dtype == torch.float16: + pytest.skip("FP16 output for NVFP4 not supported") use_fp8 = fp8_recipe is not None with 
quantized_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe): diff --git a/tests/pytorch/test_torch_compile.py b/tests/pytorch/test_torch_compile.py index 9d0ed79888..51f72b1e56 100644 --- a/tests/pytorch/test_torch_compile.py +++ b/tests/pytorch/test_torch_compile.py @@ -32,6 +32,7 @@ is_fp8_block_scaling_available, is_nvfp4_available, ) +from utils import recipe_id fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True) mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True) @@ -47,6 +48,7 @@ _all_recipes.append(recipe.MXFP8BlockScaling()) if nvfp4_available: _all_recipes.append(recipe.NVFP4BlockScaling()) + _all_recipes.append(recipe.NVFP4BlockScaling(row_scaled_activation=True)) # --------------------------------------------------------------------------- @@ -303,7 +305,7 @@ def fn(inp): @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) -@pytest.mark.parametrize("fp8_recipe", _all_recipes, ids=lambda r: type(r).__name__) +@pytest.mark.parametrize("fp8_recipe", _all_recipes, ids=recipe_id) def test_autocast_sanity(fp8_recipe): """Smoke test: torch.nn.Linear inside a single te.autocast with each built-in recipe. Forward + backward under torch.compile(fullgraph=True).""" diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index c7cbe78a6d..8ca796c268 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -117,7 +117,7 @@ def quantization_tols(name: str) -> dict[str, float]: "mxfp8_block_scaling", ): return dtype_tols(tex.DType.kFloat8E4M3) - if name == "nvfp4": + if name in ("nvfp4", "nvfp4_row_scaled"): return dtype_tols(tex.DType.kFloat4E2M1) raise ValueError(f"Unsupported quantization scheme ({name})") @@ -151,15 +151,40 @@ def make_recipe(name: Optional[str], **recipe_kwargs: Any) -> Optional[Recipe]: disable_2d_quantization=True, **recipe_kwargs, ) + if name == "nvfp4_row_scaled": + return transformer_engine.common.recipe.NVFP4BlockScaling( + disable_rht=True, + disable_stochastic_rounding=True, + disable_2d_quantization=True, + row_scaled_activation=True, + **recipe_kwargs, + ) raise ValueError(f"Unsupported quantization scheme ({name})") +def recipe_id(fp8_recipe: Optional[Recipe]) -> str: + """Readable pytest id for FP8/FP4 recipes.""" + if fp8_recipe is None: + return "None" + nvfp4 = getattr(fp8_recipe, "nvfp4", None) + if nvfp4 is not None and nvfp4() and getattr(fp8_recipe, "row_scaled_activation", False): + return "NVFP4RowScaledBlockScaling" + return type(fp8_recipe).__name__ + + def skip_unsupported_backward_override( layer_type: str, quant_recipe: Optional[Recipe], backward_override: Optional[str], ) -> None: """Skip known unsupported layer/recipe/backward-override combinations used in tests.""" + if ( + quant_recipe is not None + and quant_recipe.nvfp4() + and getattr(quant_recipe, "row_scaled_activation", False) + and backward_override is None + ): + pytest.skip("Row-scaled NVFP4 does not support default quantized backward.") if backward_override is None: return if quant_recipe is None and backward_override is not None: diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index 5d0d3c28e8..123362ce10 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -100,6 +100,14 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, int32_t rows = input_tensor->flat_first_dim(); int32_t cols = input_tensor->flat_last_dim(); auto dtype = 
input_tensor->dtype(); + const bool row_scaled_nvfp4 = output_tensor->row_scaled_nvfp4; + if (row_scaled_nvfp4) { + NVTE_CHECK(!quant_config_cpp.nvfp4_2d_quantization, + "Row-scaled NVFP4 quantization does not support 2D quantization."); + NVTE_CHECK(!output_tensor->has_columnwise_data(), + "Row-scaled NVFP4 quantization does not produce columnwise output."); + nvfp4::compute_rowwise_amax(*input_tensor, noop_tensor, output_tensor, stream); + } bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && (cols % 32 == 0) && output_tensor->has_data(); @@ -126,7 +134,9 @@ void quantize_fwd_helper(const NVTETensor input, NVTETensor output, /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, /*rng_state=*/quant_config_cpp.rng_state, /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, - /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); + /*row_scaled_nvfp4=*/row_scaled_nvfp4, + /*noop_tensor=*/noop_tensor->data, + /*stream=*/stream); } break; } @@ -239,6 +249,8 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens int32_t rows = grad_tensor->flat_first_dim(); int32_t cols = grad_tensor->flat_last_dim(); auto dtype = grad_tensor->dtype(); + NVTE_CHECK(!output_tensor->row_scaled_nvfp4, + "Backward NVFP4 quantization does not support row-scaled outputs."); bool use_optimized_kernel = (dtype == DType::kBFloat16) && (rows % 32 == 0) && (cols % 32 == 0) && output_tensor->has_data(); @@ -265,7 +277,8 @@ void quantize_bwd_helper(const NVTETensor grad, const NVTETensor input, NVTETens /*use_stochastic_rounding=*/quant_config_cpp.stochastic_rounding, /*rng_state=*/quant_config_cpp.rng_state, /*use_2d_quantization=*/quant_config_cpp.nvfp4_2d_quantization, - /*noop_tensor=*/noop_tensor->data, /*stream=*/stream); + /*row_scaled_nvfp4=*/false, /*noop_tensor=*/noop_tensor->data, + /*stream=*/stream); } break; } diff --git a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh index 4143208153..d549a050ee 100644 --- a/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/dequantize_nvfp4.cuh @@ -34,8 +34,9 @@ namespace dequantize_kernel { template __global__ void __launch_bounds__(512) dequantize_fp4_kernel(const void *const input, OType *output, const fp8e4m3 *const scales, - const float *const tensor_amax, const size_t N, const size_t M, - const size_t scale_stride, const size_t num_scale_tiles_X) { + const float *const tensor_amax, const bool row_scaled_nvfp4, + const size_t N, const size_t M, const size_t scale_stride, + const size_t num_scale_tiles_X) { const size_t thread_idx = blockIdx.x * blockDim.x + threadIdx.x; const size_t x = thread_idx % M; const size_t y = thread_idx / M; @@ -63,7 +64,7 @@ __global__ void __launch_bounds__(512) fp4vec value; value.vec = input_vectorized[my_index]; fp8e4m3 scale = scales[my_scale_index]; - float amax = *tensor_amax; + float amax = row_scaled_nvfp4 ? 
tensor_amax[y] : tensor_amax[0]; constexpr float factor_inv = 1.0 / (6.0 * 448.0); float final_scale = static_cast(scale) * amax * factor_inv; #pragma unroll @@ -90,6 +91,7 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) NVTE_CHECK(output->data.shape == input.data.shape, "Input and output shapes need to match."); const bool with_gemm_swizzled_scales = input.with_gemm_swizzled_scales; + const bool row_scaled_nvfp4 = input.row_scaled_nvfp4; constexpr int FP4_BLOCK_SIZE = 16; const size_t N = input.flat_first_dim(); @@ -103,6 +105,8 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) const size_t threads = 512; const size_t blocks = DIVUP(total, threads); const size_t num_scale_tiles_X = DIVUP(Mread, static_cast(4)); + NVTE_CHECK(!row_scaled_nvfp4 || input.amax.numel() == N, + "Row-scaled NVFP4 dequantization requires one rowwise amax per row."); TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( output->data.dtype, OType, @@ -112,7 +116,8 @@ inline void dequantize(const Tensor &input, Tensor *output, cudaStream_t stream) dequantize_fp4_kernel<<>>( input.data.dptr, reinterpret_cast(output->data.dptr), reinterpret_cast(input.scale_inv.dptr), - reinterpret_cast(input.amax.dptr), N, Mread, input.scale_inv.shape.back(), + reinterpret_cast(input.amax.dptr), row_scaled_nvfp4, N, Mread, + input.scale_inv.shape.back(), num_scale_tiles_X);); // NOLINT(*) ); // NOLINT(*) NVTE_CHECK_CUDA(cudaGetLastError()); diff --git a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh index f164636e38..9e4aef5a1c 100644 --- a/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh +++ b/transformer_engine/common/cast/nvfp4/quantize_transpose_nvfp4.cuh @@ -16,6 +16,8 @@ #include #include +#include + #include "../../common.h" #include "../../util/math.h" #include "../../util/ptx.cuh" @@ -27,6 +29,132 @@ namespace transformer_engine { namespace dispatch { namespace nvfp4 { +namespace rowwise_amax_kernel { + +using namespace ptx; + +#if FP4_TYPE_SUPPORTED + +constexpr int ROWWISE_AMAX_BLOCK_SIZE = 256; +constexpr int ROWWISE_AMAX_SF_VEC_SIZE = 16; + +template +__device__ __forceinline__ void abs_max_2x_update(ptx::FPx2 &dst, + const ptx::FPx2 &val) { + if constexpr (std::is_same_v) { + dst.x = fmaxf(fabsf(dst.x), fabsf(val.x)); + dst.y = fmaxf(fabsf(dst.y), fabsf(val.y)); + } else { + ptx::abs_max_2x(dst, dst, val); + } +} + +template +__device__ __forceinline__ float abs_max_2x_to_float(const ptx::FPx2 &val) { + if constexpr (std::is_same_v) { + return fmaxf(fabsf(val.x), fabsf(val.y)); + } else { + return static_cast(__hmax(__habs(val.x), __habs(val.y))); + } +} + +template +__global__ void +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +__launch_bounds__(BLOCK_SIZE) +#endif + compute_rowwise_amax_kernel(const int num_rows, const int num_cols, + const IType *__restrict__ input, + float *__restrict__ output_rowwise_amax, + const float *__restrict__ noop) { +#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ < 1000) + NVTE_DEVICE_ERROR("SM 10.0+ is required."); +#else + if (noop != nullptr && noop[0] == 1.0f) { + return; + } + + using IType2 = typename ptx::FPx2; + + const int row_idx = blockIdx.x; + if (row_idx >= num_rows) return; + + const int num_vec2 = num_cols / 2; + const IType2 *input_row = reinterpret_cast(input + row_idx * num_cols); + + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; + for (int i = threadIdx.x; i < num_vec2; i += 
BLOCK_SIZE) { + const IType2 val = input_row[i]; + abs_max_2x_update(thread_amax_2x, val); + } + const float thread_max = abs_max_2x_to_float(thread_amax_2x); + + const float row_amax = + reduce_max(thread_max, threadIdx.x / THREADS_PER_WARP); + + if (threadIdx.x == 0) { + output_rowwise_amax[row_idx] = row_amax; + } +#endif +} + +template +void launch_compute_rowwise_amax(const int num_rows, const int num_cols, const IType *input, + float *output_rowwise_amax, cudaStream_t stream, + const float *noop = nullptr) { + if (num_rows == 0 || num_cols == 0) return; + + dim3 grid(num_rows); + dim3 block(ROWWISE_AMAX_BLOCK_SIZE); + + compute_rowwise_amax_kernel + <<>>(num_rows, num_cols, input, output_rowwise_amax, noop); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +#endif // FP4_TYPE_SUPPORTED + +} // namespace rowwise_amax_kernel + +inline void compute_rowwise_amax(const Tensor &input, const Tensor *noop, Tensor *output, + cudaStream_t stream) { +#if FP4_TYPE_SUPPORTED + using namespace rowwise_amax_kernel; + + const size_t rows = input.flat_first_dim(); + const size_t cols = input.flat_last_dim(); + NVTE_CHECK(cols % ROWWISE_AMAX_SF_VEC_SIZE == 0, + "Row-scaled NVFP4 quantization requires last dim divisible by ", + ROWWISE_AMAX_SF_VEC_SIZE, "."); + + auto *amax_ptr = reinterpret_cast(output->amax.dptr); + NVTE_CHECK(amax_ptr != nullptr, "Row-scaled rowwise amax tensor must be allocated."); + NVTE_CHECK(output->amax.numel() == rows, "Row-scaled rowwise amax must have ", rows, + " entries, got ", output->amax.shape, "."); + + const auto *noop_ptr = reinterpret_cast(noop->data.dptr); + if (input.dtype() == DType::kBFloat16) { + const auto *input_ptr = reinterpret_cast(input.data.dptr); + launch_compute_rowwise_amax<__nv_bfloat16>(static_cast(rows), static_cast(cols), + input_ptr, amax_ptr, stream, noop_ptr); + } else if (input.dtype() == DType::kFloat16) { + const auto *input_ptr = reinterpret_cast(input.data.dptr); + launch_compute_rowwise_amax(static_cast(rows), static_cast(cols), input_ptr, + amax_ptr, stream, noop_ptr); + } else if (input.dtype() == DType::kFloat32) { + const auto *input_ptr = reinterpret_cast(input.data.dptr); + launch_compute_rowwise_amax(static_cast(rows), static_cast(cols), input_ptr, + amax_ptr, stream, noop_ptr); + } else { + NVTE_ERROR( + "Unsupported input dtype for row-scaled NVFP4 quantization. " + "Expected BFloat16, Float16, or Float32."); + } +#else + NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); +#endif // FP4_TYPE_SUPPORTED +} + namespace quantize_transpose_kernel { using namespace quantization_and_transposition_SF; @@ -108,7 +236,8 @@ constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4 * 8) / 4; // 256 constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM; // 8 = 128 / 16 template + typename IType, bool USE_STOCHASTIC_ROUNDING, bool RETURN_TRANSPOSE, + bool ROW_SCALED_NVFP4> __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_kernel(const __grid_constant__ CUtensorMap tensor_map_input, const __grid_constant__ CUtensorMap tensor_map_output, @@ -508,27 +637,56 @@ __global__ void __launch_bounds__(THREADS_NUM) } } - // 2. Compute E4M3 scaling factor - const nvfp4_scale_t S_dec_b_fp8 = - compute_decoding_scaling_factor(block_amax, S_enc_rowwise); + float block_scale_inverse; + if constexpr (ROW_SCALED_NVFP4) { + // 2. 
Compute E4M3 scaling factor + const size_t scales_offset_Y = + scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE; + const float S_enc_rowwise_block = + scales_offset_Y < rows + ? compute_global_encode_scaling_factor_FP4(amax_rowwise_ptr[scales_offset_Y]) + : 1.0f; + const float S_dec_rowwise_block = 1.0f / S_enc_rowwise_block; + const nvfp4_scale_t S_dec_b_fp8 = + compute_decoding_scaling_factor(block_amax, S_enc_rowwise_block); + + // Check boundaries + const size_t scales_offset_X = scales_offset_X_rowwise; + const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X; + + const bool rowwise_scale_is_within_bounds_Y = + (stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows; + if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) { + scales_ptr[scale_idx_global] = S_dec_b_fp8; + } - // Check boundaries - const size_t scales_offset_Y = - scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE; - const size_t scales_offset_X = scales_offset_X_rowwise; - const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X; + // Compute "correct" per-block encoding scaling factor + constexpr float float_max = detail::TypeExtrema::max; + block_scale_inverse = + fminf(1.0f / (static_cast(S_dec_b_fp8) * S_dec_rowwise_block), + float_max); // S_enc_b_fp8 + } else { + // 2. Compute E4M3 scaling factor + const nvfp4_scale_t S_dec_b_fp8 = + compute_decoding_scaling_factor(block_amax, S_enc_rowwise); + + // Check boundaries + const size_t scales_offset_Y = + scales_offset_Y_rowwise + stage * BUFF_DIM_Y + it * THREADS_Y_ROWWISE; + const size_t scales_offset_X = scales_offset_X_rowwise; + const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X; + + const bool rowwise_scale_is_within_bounds_Y = + (stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows; + if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) { + scales_ptr[scale_idx_global] = S_dec_b_fp8; + } - // const bool rowwise_scale_is_within_bounds_Y = scales_offset_Y < rows; - const bool rowwise_scale_is_within_bounds_Y = - (stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows; - if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) { - scales_ptr[scale_idx_global] = S_dec_b_fp8; + // Compute "correct" per-block encoding scaling factor + constexpr float float_max = detail::TypeExtrema::max; + block_scale_inverse = fminf(1.0f / (static_cast(S_dec_b_fp8) * S_dec_rowwise), + float_max); // S_enc_b_fp8 } - - // Compute "correct" per-block encoding scaling factor - constexpr float float_max = detail::TypeExtrema::max; - const float block_scale_inverse = fminf( - 1.0f / (static_cast(S_dec_b_fp8) * S_dec_rowwise), float_max); // S_enc_b_fp8 const float2 block_scale_inverse_2x{block_scale_inverse, block_scale_inverse}; // 3. 
Scale elements @@ -1051,7 +1209,6 @@ __global__ void __launch_bounds__(THREADS_NUM) const size_t scales_offset_X = scales_offset_X_rowwise; const size_t scale_idx_global = scales_offset_Y * scale_stride + scales_offset_X; - // const bool rowwise_scale_is_within_bounds_Y = scales_offset_Y < rows; const bool rowwise_scale_is_within_bounds_Y = (stage_rowwise_scales_offset_Y + it * THREADS_Y_ROWWISE + tid_Y_rowwise) < chunk_rows; if (rowwise_scale_is_within_bounds_X && rowwise_scale_is_within_bounds_Y) { @@ -1162,6 +1319,9 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, using namespace ptx; bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false; + const bool row_scaled_nvfp4 = output->row_scaled_nvfp4; + NVTE_CHECK(!row_scaled_nvfp4 || !use_2d_quantization, + "Row-scaled NVFP4 quantization does not support 2D quantization."); // If transposed output is allocated, return the transposed data. Otherwise, it's not necessary to // return the transposed data. @@ -1186,6 +1346,10 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, NVTE_CHECK(output->has_data(), "NVFP4 output tensor must be allocated."); NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type."); NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + NVTE_CHECK(!row_scaled_nvfp4 || output->amax.dptr != nullptr, + "Row-scaled NVFP4 quantization requires rowwise amax."); + NVTE_CHECK(!row_scaled_nvfp4 || !output->has_columnwise_data(), + "Row-scaled NVFP4 quantization does not produce columnwise output."); NVTE_CHECK(!output->with_gemm_swizzled_scales, "Output must have scales in compact format."); if (return_transpose) { NVTE_CHECK(output->has_columnwise_data(), "NVFP4 transposed output tensor must be allocated."); @@ -1268,20 +1432,23 @@ void quantize_transpose(const Tensor &input, const Tensor *noop, Tensor *output, TRANSFORMER_ENGINE_SWITCH_CONDITION( use_stochastic_rounding, USE_STOCHASTIC_ROUNDING, - TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { - auto kernel = quantize_transpose_nvfp4_kernel; + TRANSFORMER_ENGINE_SWITCH_CONDITION(row_scaled_nvfp4, ROW_SCALED_NVFP4, { + TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { + auto kernel = quantize_transpose_nvfp4_kernel; - if constexpr (use_2d_quantization) { - kernel = quantize_transpose_nvfp4_2D_kernel; - } + if constexpr (use_2d_quantization) { + kernel = quantize_transpose_nvfp4_2D_kernel; + } - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); - kernel<<>>( - tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr, - scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, cols, - scale_stride, scale_stride_transpose, rng_state); + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + kernel<<>>( + tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr, + scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, cols, + scale_stride, scale_stride_transpose, rng_state); + }); });); #else NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); diff --git a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh index fc337f6078..8adda82131 100644 ---
a/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh +++ b/transformer_engine/common/cast/nvfp4/specialized/quantize_transpose_nvfp4_tuned_1D.cuh @@ -261,14 +261,12 @@ __device__ __forceinline__ void colwise_scaling(const IType *__restrict__ sIn_pt } } -template -__device__ __forceinline__ void rowwise_scaling(const IType *__restrict__ sIn_ptr, - fp4e2m1x2 *__restrict__ sOut_ptr, - nvfp4_scale_t *__restrict__ sSFrowwise_ptr, - const float S_enc_rowwise, const int stage_Y, - const int stage_X, const int buff_in, - const int buff_out, RNG_t &rng, uint4 &random_uint4, - int &rnd_idx) { +template +__device__ __forceinline__ void rowwise_scaling( + const IType *__restrict__ sIn_ptr, fp4e2m1x2 *__restrict__ sOut_ptr, + nvfp4_scale_t *__restrict__ sSFrowwise_ptr, const float S_enc_rowwise, const int stage_Y, + const int stage_X, const int buff_in, const int buff_out, const float *amax_rowwise_ptr, + const size_t row_offset, const size_t rows, RNG_t &rng, uint4 &random_uint4, int &rnd_idx) { using scaling_coeff_type = typename SCALING_COEFFICIENT_TYPE::type; const auto &sIn = *reinterpret_cast(sIn_ptr); @@ -315,9 +313,21 @@ __device__ __forceinline__ void rowwise_scaling(const IType *__restrict__ sIn_pt } const float block_amax = get_amax_of_pair(thread_amax_2x); - const nvfp4_scale_t S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc_rowwise); - const scaling_coeff_type SFcoefficient = - compute_nvfp4_scaling_coefficient(S_dec_b_fp8, S_enc_rowwise); + nvfp4_scale_t S_dec_b_fp8; + scaling_coeff_type SFcoefficient; + if constexpr (ROW_SCALED_NVFP4) { + const size_t row_idx = row_offset + stage_Y * TILE_DIM_Y + it_offset_Y_rowwise; + const float S_enc_rowwise_block = + row_idx < rows ? core::compute_global_encode_scaling_factor_FP4(amax_rowwise_ptr[row_idx]) + : 1.0f; + S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc_rowwise_block); + SFcoefficient = + compute_nvfp4_scaling_coefficient(S_dec_b_fp8, S_enc_rowwise_block); + } else { + S_dec_b_fp8 = compute_decoding_scaling_factor(block_amax, S_enc_rowwise); + SFcoefficient = + compute_nvfp4_scaling_coefficient(S_dec_b_fp8, S_enc_rowwise); + } // Store scaling factors to SMEM buffer (R2S) if (SF_storing_thread) { @@ -350,7 +360,8 @@ __device__ __forceinline__ void rowwise_scaling(const IType *__restrict__ sIn_pt } } -template +template __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_tuned_1D_kernel( const __grid_constant__ CUtensorMap tensor_map_input, const __grid_constant__ CUtensorMap tensor_map_output, @@ -571,9 +582,9 @@ __global__ void __launch_bounds__(THREADS_NUM) quantize_transpose_nvfp4_tuned_1D ptx::cp_async_bulk_wait_group_read(); // NVFP4 Quantization - rowwise_scaling( + rowwise_scaling( sIn_ptr, sOut_ptr, sSFrowwise_ptr, S_enc_rowwise, stage_Y, stage_X, buff_in, buff_out, - rng, random_uint4, rnd_idx); + amax_rowwise_ptr, block_offset_Y, rows, rng, random_uint4, rnd_idx); if constexpr (RETURN_TRANSPOSE) { colwise_scaling( @@ -680,6 +691,7 @@ inline void quantize_transpose_tuned_1D(const Tensor &input, const Tensor *noop, const bool use_stochastic_rounding = quant_config ? quant_config->stochastic_rounding : false; const bool use_fast_math = quant_config ? quant_config->use_fast_math : false; + const bool row_scaled_nvfp4 = output->row_scaled_nvfp4; // If transposed output is allocated, return the transposed data // Otherwise, it's not necessary to return the transposed data.
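For reference, the scale arithmetic implemented by the row-scaled branch above can be sketched in a few lines of Python. This is a hedged illustration, not the kernel: FP8/FP4 rounding and the float-max clamp are omitted, all names are illustrative, and the constants are the E2M1 max-normal (6.0) and E4M3 max-normal (448.0) used throughout this recipe.

import numpy as np

FP4_MAX, FP8_MAX = 6.0, 448.0

def encode_scale(amax: float) -> float:
    # Per-tensor (or, under row scaling, per-row) global encode scale.
    return (FP8_MAX * FP4_MAX) / amax if amax > 0.0 else 1.0

def block_scales(block_amax: float, amax: float):
    S_enc = encode_scale(amax)
    # Per-block decode scale; the kernel stores this as FP8 E4M3.
    S_dec_b = min(block_amax / FP4_MAX * S_enc, FP8_MAX)
    # Per-block encode scale actually applied to the block's elements,
    # i.e. 1 / (S_dec_b * (1 / S_enc)) in the kernel's terms.
    S_enc_b = S_enc / S_dec_b if S_dec_b != 0.0 else 0.0
    return S_dec_b, S_enc_b

x = np.random.randn(4, 32).astype(np.float32)
for row in x:                            # row scaling: one amax per row
    row_amax = float(np.abs(row).max())
    for blk in row.reshape(-1, 16):      # NVFP4 block size is 16
        S_dec_b, S_enc_b = block_scales(float(np.abs(blk).max()), row_amax)
        q = np.clip(blk * S_enc_b, -FP4_MAX, FP4_MAX)  # then FP4 rounding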
@@ -694,6 +706,10 @@ inline void quantize_transpose_tuned_1D(const Tensor &input, const Tensor *noop, NVTE_CHECK(output->has_data(), "NVFP4 output tensor must be allocated."); NVTE_CHECK(is_fp4_dtype(output->data.dtype), "Output must have FP4 type."); NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated"); + NVTE_CHECK(!row_scaled_nvfp4 || output->amax.dptr != nullptr, + "Row-scaled NVFP4 quantization requires rowwise amax."); + NVTE_CHECK(!row_scaled_nvfp4 || !output->has_columnwise_data(), + "Row-scaled NVFP4 quantization does not produce columnwise output."); if (return_transpose) { NVTE_CHECK(is_fp4_dtype(output->columnwise_data.dtype), @@ -783,16 +799,20 @@ inline void quantize_transpose_tuned_1D(const Tensor &input, const Tensor *noop, use_stochastic_rounding, USE_STOCHASTIC_ROUNDING, TRANSFORMER_ENGINE_SWITCH_CONDITION( use_fast_math, USE_FAST_MATH, - TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { - auto kernel = quantize_transpose_nvfp4_tuned_1D_kernel; - - cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); - kernel<<>>( - tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr, - scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, cols, - scale_stride, scale_stride_transpose, rng_state); - }););); + TRANSFORMER_ENGINE_SWITCH_CONDITION( + row_scaled_nvfp4, ROW_SCALED_NVFP4, + TRANSFORMER_ENGINE_SWITCH_CONDITION(return_transpose, RETURN_TRANSPOSE, { + auto kernel = + quantize_transpose_nvfp4_tuned_1D_kernel; + + cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + dshmem_size); + kernel<<>>( + tensor_map_input, tensor_map_output, tensor_map_output_transpose, scales_ptr, + scales_transpose_ptr, noop_ptr, amax_rowwise_ptr, amax_colwise_ptr, rows, cols, + scale_stride, scale_stride_transpose, rng_state); + });););); #else NVTE_ERROR("FP4 support requires CUDA 12.8+, but compile-time CUDA version is ", CUDA_VERSION); #endif // FP4_TYPE_SUPPORTED diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 133f1a09e6..28218e2b43 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -222,6 +222,14 @@ TensorWrapper CommOverlapCore::get_tensor_chunk(const TensorWrapper &source, siz TensorWrapper chunk(scaling_mode); for (int param_id = 0; param_id < NVTETensorParam::kNVTENumTensorParams; param_id++) { auto param_type = static_cast(param_id); + if (param_type == NVTETensorParam::kNVTEWithGEMMSwizzledScales) { + chunk.set_with_gemm_swizzled_scales(source.get_with_gemm_swizzled_scales()); + continue; + } + if (param_type == NVTETensorParam::kNVTERowScaledNVFP4) { + chunk.set_row_scaled_nvfp4(source.get_row_scaled_nvfp4()); + continue; + } auto param = source.get_parameter(param_type); auto param_dptr = reinterpret_cast(param.data_ptr); auto param_dtype = static_cast(param.dtype); diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index c1b3f8f427..12479f2a9c 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -173,6 +173,11 @@ struct Tensor { * Only meaningful for MXFP8 and NVFP4. */ bool with_gemm_swizzled_scales = false; + /*! \brief Whether NVFP4 rowwise amax metadata is row-scaled. + * + * Only meaningful for NVFP4 tensors. + */ + bool row_scaled_nvfp4 = false; /*! 
Map from NVTETensorParam to parameter sizes */ static constexpr size_t attr_sizes[] = { @@ -183,7 +188,8 @@ struct Tensor { sizeof(NVTEBasicTensor), // kNVTERowwiseScaleInv sizeof(NVTEBasicTensor), // kNVTEColumnwiseScaleInv sizeof(NVTEBasicTensor), // kNVTEColumnwiseAmax - sizeof(uint8_t) // kNVTEWithGEMMSwizzledScales + sizeof(uint8_t), // kNVTEWithGEMMSwizzledScales + sizeof(uint8_t) // kNVTERowScaledNVFP4 }; Tensor() : scaling_mode{NVTE_DELAYED_TENSOR_SCALING}, nvte_tensor{0} {} @@ -199,6 +205,7 @@ struct Tensor { columnwise_scale_inv.clear(); scaling_mode = NVTE_DELAYED_TENSOR_SCALING; with_gemm_swizzled_scales = false; + row_scaled_nvfp4 = false; } explicit operator NVTETensor() const noexcept { return nvte_tensor; } diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 144aea1a07..8589d7045d 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -318,6 +318,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, const void *alpha, const void *beta, bool use_split_accumulator, int math_sm_count, int m_split, int n_split, bool gemm_producer, const Tensor *inputCounter, cudaStream_t stream) { + NVTE_CHECK(!inputA->row_scaled_nvfp4 && !inputB->row_scaled_nvfp4, + "cuBLAS GEMM does not support row-scaled NVFP4 inputs."); + // Tensor dims in row-major order const int A0 = inputA->flat_first_dim(); const int A1 = inputA->flat_last_dim(); diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index bf9394c988..9fe692dd2d 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -440,7 +440,7 @@ void nvte_grouped_bias_add(const NVTEGroupedTensor output, const NVTEGroupedTens /*! \brief Grouped Scaled Bias add for grouped GEMM outputs. * * output[row,col] += bias[col] * scale[row], where biases are per-group -* and scales are per-token (per-row across all groups). +* and scales are per-row across all groups. * Requires uniform last-dimension across all output tensors and bias tensors. 
*/ void nvte_grouped_scaled_bias_add(const NVTEGroupedTensor output, const NVTEGroupedTensor bias, diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index b7461a85d1..e9a6f4f735 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -72,6 +72,7 @@ enum NVTETensorParam { kNVTEColumnwiseScaleInv = 5, /*!< Scale inverse tensor for decoding Columnwise Data */ kNVTEColumnwiseAmax = 6, /*!< Columnwise Amax tensor */ kNVTEWithGEMMSwizzledScales = 7, /*!< Whether scaling factors are in format expected by GEMM */ + kNVTERowScaledNVFP4 = 8, /*!< Whether an NVFP4 tensor uses row scaling */ kNVTENumTensorParams }; @@ -765,6 +766,11 @@ class TensorWrapper { nvte_set_tensor_param_v2(tensor_, kNVTEWithGEMMSwizzledScales, &val, sizeof(val)); } + void set_row_scaled_nvfp4(bool row_scaled_nvfp4) { + const auto val = static_cast(row_scaled_nvfp4); + nvte_set_tensor_param_v2(tensor_, kNVTERowScaledNVFP4, &val, sizeof(val)); + } + // Parameter getters NVTEBasicTensor get_parameter(const NVTETensorParam param) const noexcept { @@ -801,6 +807,12 @@ class TensorWrapper { return static_cast(val); } + bool get_row_scaled_nvfp4() const { + uint8_t val = 0; + nvte_get_tensor_param_v2(tensor_, kNVTERowScaledNVFP4, &val, sizeof(val), nullptr); + return static_cast(val); + } + /*! \brief Get an underlying NVTETensor. * * \return NVTETensor held by this TensorWrapper. diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 67b6f87067..0d0b2fd37f 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -478,6 +478,10 @@ class NVFP4BlockScaling(Recipe): If set to `True`, stochastic rounding is disabled during quantization for all tensors. disable_2d_quantization : bool, default = False If set to `True`, 1D block scaling with block size 16 is used for all tensors. + row_scaled_activation : bool, default = False + If set to `True`, forward activation quantizers emit row-scaled + NVFP4 tensors. In this mode, rowwise ``amax`` metadata is stored + as a vector with one FP32 value per tensor row. backward_override : {None, 'high_precision', 'dequantized'}, default = None Backward precision mode. 
None does not modify backward behavior, `high_precision` keeps original high-precision operands for backward, @@ -491,6 +495,7 @@ class NVFP4BlockScaling(Recipe): os.getenv("NVTE_NVFP4_DISABLE_STOCHASTIC_ROUNDING", "0") == "1" ) disable_2d_quantization: bool = os.getenv("NVTE_NVFP4_DISABLE_2D_QUANTIZATION", "0") == "1" + row_scaled_activation: bool = os.getenv("NVTE_NVFP4_ROW_SCALED_ACTIVATION", "0") == "1" fp4_format: Format = Format.E2M1 fp8_format: Format = Format.E4M3 @@ -534,6 +539,7 @@ def __repr__(self) -> str: f"fp8_dpa={self.fp8_dpa}, " f"fp8_mha={self.fp8_mha}, " f"backward_override={self.backward_override}, " + f"row_scaled_activation={self.row_scaled_activation}, " f"fp4_quant_fwd_inp={self.fp4_quant_fwd_inp}, " f"fp4_quant_fwd_weight={self.fp4_quant_fwd_weight}, " f"fp4_quant_bwd_grad={self.fp4_quant_bwd_grad}, " diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 1261879a8b..1a52d76019 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -852,6 +852,9 @@ void nvte_set_tensor_param_v2(NVTETensor tensor, NVTETensorParam param, const vo case kNVTEWithGEMMSwizzledScales: t.with_gemm_swizzled_scales = static_cast(*reinterpret_cast(buf)); break; + case kNVTERowScaledNVFP4: + t.row_scaled_nvfp4 = static_cast(*reinterpret_cast(buf)); + break; default: NVTE_ERROR("Unsupported tensor parameter (", static_cast(param), ")"); } @@ -932,6 +935,9 @@ void nvte_get_tensor_param_v2(const NVTETensor tensor, NVTETensorParam param, vo case kNVTEWithGEMMSwizzledScales: *reinterpret_cast(buf) = static_cast(t->with_gemm_swizzled_scales); break; + case kNVTERowScaledNVFP4: + *reinterpret_cast(buf) = static_cast(t->row_scaled_nvfp4); + break; default: NVTE_ERROR("Unsupported tensor parameter (", static_cast(param), ")"); } diff --git a/transformer_engine/common/transpose/cast_transpose.h b/transformer_engine/common/transpose/cast_transpose.h index a5ec2306b1..c462b30147 100644 --- a/transformer_engine/common/transpose/cast_transpose.h +++ b/transformer_engine/common/transpose/cast_transpose.h @@ -67,7 +67,7 @@ void quantize_transpose_vector_blockwise_fp4( SimpleTensor &scale_inv_t, SimpleTensor &output, SimpleTensor &output_t, const float epsilon, const bool return_identity, const bool return_transpose, const bool pow2_scale, const bool swizzled_scale, const bool use_stochastic_rounding, - const NVTETensor rng_state_tensor, const bool use_2d_quantization, + const NVTETensor rng_state_tensor, const bool use_2d_quantization, const bool row_scaled_nvfp4, const SimpleTensor &noop_tensor, cudaStream_t stream); } // namespace transformer_engine::detail diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu index d3d3dceca9..cf9821f1a9 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise_fp4.cu @@ -316,7 +316,7 @@ __device__ __forceinline__ __nv_fp4x4_e2m1 cvt_fp32_to_fp4_4x(const float2 in01, template + bool kApplyStochasticRounding, bool kIs2DBlockScaling, bool kRowScaledNVFP4> __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpose_kernel( const IType* const input, const float* global_amax, OType* const output_c, OType* const output_t, ScaleType* const tile_scales_inv_c, ScaleType* const tile_scales_inv_t, @@ -509,8 
+509,19 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo amax = amax_smem[data_row_idx / kFP4BlockScalingSize][tid_in_warp_x]; } // Step 2.4: Compute scale - ScaleType scale_inv = ComputeDecodeScaleFP4(amax, global_encode_scale_multiplier); - float encode_scale = ComputeEncodeScaleFP4(scale_inv, global_decode_scale); + const size_t row_idx = block_idx_y * kTileDim + r_s; + float row_global_encode_scale = global_encode_scale; + if constexpr (kRowScaledNVFP4) { + row_global_encode_scale = + row_idx < num_rows ? ComputeGlobalEncodeScaleFP4(global_amax[row_idx]) : 1.0f; + } + const float row_global_encode_scale_multiplier = + kRowScaledNVFP4 ? row_global_encode_scale * fp4_max_inv : global_encode_scale_multiplier; + const float row_global_decode_scale = + kRowScaledNVFP4 ? 1.0f / row_global_encode_scale : global_decode_scale; + ScaleType scale_inv = + ComputeDecodeScaleFP4(amax, row_global_encode_scale_multiplier); + float encode_scale = ComputeEncodeScaleFP4(scale_inv, row_global_decode_scale); // Step 2.5: Write scale_inv bool write_scale_inv = is_src_lane; if constexpr (!kAligned) { @@ -708,7 +719,7 @@ void quantize_transpose_vector_blockwise_fp4( SimpleTensor& scale_inv_t, SimpleTensor& output, SimpleTensor& output_t, const float epsilon, const bool return_identity, const bool return_transpose, const bool pow2_scale, const bool swizzled_scale, const bool use_stochastic_rounding, - const NVTETensor rng_state_tensor, const bool use_2d_quantization, + const NVTETensor rng_state_tensor, const bool use_2d_quantization, const bool row_scaled_nvfp4, const SimpleTensor& noop_tensor, cudaStream_t stream) { NVTE_API_CALL(quantize_transpose_vector_blockwise_fp4); #if CUDA_VERSION >= 12080 @@ -722,6 +733,10 @@ void quantize_transpose_vector_blockwise_fp4( NVTE_CHECK(return_identity || !use_2d_quantization, "2D block quantization is only supported when return_identity is true."); + NVTE_CHECK(!row_scaled_nvfp4 || (return_identity && !return_transpose), + "Row-scaled NVFP4 quantization only supports rowwise quantization."); + NVTE_CHECK(!row_scaled_nvfp4 || !use_2d_quantization, + "Row-scaled NVFP4 quantization does not support 2D quantization."); const size_t row_length = input.shape.size() > 0 ? 
input.shape.at(input.shape.size() - 1) : 1u; size_t num_elements = row_length; @@ -801,35 +816,41 @@ void quantize_transpose_vector_blockwise_fp4( TRANSFORMER_ENGINE_SWITCH_CONDITION( use_2d_quantization, kIs2DBlockScaling, - size_t smem_bytes = kSMemSize * sizeof(InputType); - auto kernel = block_scaled_1d_cast_transpose_kernel< - kReturnIdentity, kReturnTranspose, kPow2Scale, kAligned, - float, InputType, OutputType, ScaleType, kSwizzledScale, - kApplyStochasticRounding, kIs2DBlockScaling>; - if (smem_bytes >= 48 * 1024) { - cudaError_t err = cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_bytes); - NVTE_CHECK(err == cudaSuccess, - "Failed to set dynamic shared memory size."); - } kernel<<>>( - reinterpret_cast(input.dptr), - reinterpret_cast(global_amax.dptr), - reinterpret_cast(output.dptr), - reinterpret_cast(output_t.dptr), - reinterpret_cast(scale_inv.dptr), - reinterpret_cast(scale_inv_t.dptr), row_length, - num_rows, scale_stride_x, scale_stride_y, scale_t_stride_x, - scale_t_stride_y, kScaleBlockDim, epsilon, rng_state, - noop_ptr);) // kIs2DBlockScaling - ) // kApplyStochasticRounding - ) // kSwizzledScale - ) // kAligned - ) // kReturnTranspose - ) // kReturnIdentity - ) // OutputType - ) // InputType + TRANSFORMER_ENGINE_SWITCH_CONDITION( + row_scaled_nvfp4, kRowScaledNVFP4, + + size_t smem_bytes = kSMemSize * sizeof(InputType); + auto kernel = block_scaled_1d_cast_transpose_kernel< + kReturnIdentity, kReturnTranspose, kPow2Scale, kAligned, + float, InputType, OutputType, ScaleType, kSwizzledScale, + kApplyStochasticRounding, kIs2DBlockScaling, + kRowScaledNVFP4>; + if (smem_bytes >= 48 * 1024) { + cudaError_t err = cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_bytes); + NVTE_CHECK(err == cudaSuccess, + "Failed to set dynamic shared memory size."); + } kernel<<>>( + reinterpret_cast(input.dptr), + reinterpret_cast(global_amax.dptr), + reinterpret_cast(output.dptr), + reinterpret_cast(output_t.dptr), + reinterpret_cast(scale_inv.dptr), + reinterpret_cast(scale_inv_t.dptr), + row_length, num_rows, scale_stride_x, scale_stride_y, + scale_t_stride_x, scale_t_stride_y, kScaleBlockDim, + epsilon, rng_state, + noop_ptr);) // kRowScaledNVFP4 + ) // kIs2DBlockScaling + ) // kApplyStochasticRounding + ) // kSwizzledScale + ) // kAligned + ) // kReturnTranspose + ) // kReturnIdentity + ) // OutputType + ) // InputType NVTE_CHECK_CUDA(cudaGetLastError()); #else diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 6f3553bf94..edf2c1e1c2 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -15,6 +15,8 @@ from ..quantized_tensor import Quantizer from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage +from ..tensor.storage.grouped_tensor_storage import GroupedTensorStorage +from ..tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage from ..tensor.utils import is_custom from ..custom_recipes.gemm import custom_gemm from ...debug.pytorch.debug_quantization import DebugQuantizer @@ -69,6 +71,38 @@ def validate_gemm_scale(scale: Optional[float], required: bool) -> float: return 0.0 +def _is_nvfp4_row_scaled_tensor(tensor: torch.Tensor) -> bool: + """Whether tensor carries row-scaled NVFP4 global amax metadata.""" + return isinstance(tensor, NVFP4TensorStorage) and tensor._row_scaled_nvfp4 + + +def _nvfp4_row_scaled_gemm_inputs( + A: 
NVFP4TensorStorage, + B: NVFP4TensorStorage, + *, + transa: bool, +) -> Tuple[NVFP4TensorStorage, NVFP4TensorStorage, torch.Tensor]: + """Return GEMM aliases and FP32 output scales for row-scaled NVFP4.""" + A_metadata = A.get_metadata() + weight_amax = A._amax_rowwise if transa else A._amax_columnwise + assert weight_amax is not None and weight_amax.numel() == 1 + A_metadata["amax_rowwise" if transa else "amax_columnwise"] = weight_amax.new_ones(1) + A_metadata["row_scaled_nvfp4"] = False + + B_metadata = B.get_metadata() + rhs_rowwise_amax = B._amax_rowwise + assert rhs_rowwise_amax is not None + B_metadata["amax_rowwise"] = rhs_rowwise_amax.new_ones(1) + B_metadata["row_scaled_nvfp4"] = False + + assert rhs_rowwise_amax.dtype == torch.float32 and weight_amax.dtype == torch.float32 + return ( + NVFP4TensorStorage(**A_metadata), + NVFP4TensorStorage(**B_metadata), + (rhs_rowwise_amax * weight_amax).view(-1, 1), + ) + + def general_gemm( A: torch.Tensor, B: torch.Tensor, @@ -174,7 +208,65 @@ def general_gemm( "beta": beta, } - out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs) + if not _is_nvfp4_row_scaled_tensor(A) and not _is_nvfp4_row_scaled_tensor(B): + out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs) + else: + if _is_nvfp4_row_scaled_tensor(A): + raise NotImplementedError("Row-scaled NVFP4 GEMM does not support row-scaled A.") + assert layout[1] == "N", "Row-scaled NVFP4 GEMM currently supports N-layout B only." + if grad: + raise RuntimeError( + "Row-scaled NVFP4 GEMM currently supports fprop only. " + "Backward NVFP4 gradient quantizers should use scalar global amax." + ) + assert not gelu, "Row-scaled NVFP4 GEMM currently does not support fused GELU." + assert not accumulate, "Row-scaled NVFP4 GEMM currently does not support accumulation." + assert ( + quantization_params is None + ), "Row-scaled NVFP4 GEMM currently does not support output quantization." + assert ub is None, "Row-scaled NVFP4 GEMM currently does not support CommOverlap." + assert ( + extra_output is None + ), "Row-scaled NVFP4 GEMM currently does not support extra output." + assert not bulk_overlap, "Row-scaled NVFP4 GEMM currently does not support bulk overlap." + assert out is None or ( + isinstance(out, torch.Tensor) and not is_custom(out) + ), "Row-scaled NVFP4 GEMM currently supports only plain torch.Tensor outputs." + assert isinstance( + A, NVFP4TensorStorage + ), "Row-scaled NVFP4 GEMM currently requires NVFP4 A." + # cuBLAS folds NVFP4 global amax values into GEMM alpha. Keep the row-scaled + # recipe's global scales out of alpha and apply them in FP32 below. 
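# Hedged sketch (illustrative names; not part of the patch) of the rescale
# performed a few lines below: with the amax aliases above replaced by ones,
# cuBLAS no longer folds the global scales into GEMM alpha, so the FP32
# output is corrected afterwards by a rank-1 scaling over its rows.
_out_fp32 = torch.randn(8, 4)             # stand-in for the FP32 GEMM result
_act_row_amax = torch.rand(8)             # one FP32 amax per row of B
_w_amax = torch.tensor(2.0)               # scalar global amax of A (weight)
_out_fp32.mul_((_act_row_amax * _w_amax).view(-1, 1))  # broadcast per row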
+ gemm_A, gemm_B, rowwise_global_scales = _nvfp4_row_scaled_gemm_inputs(A, B, transa=transa) + + requested_out, requested_out_dtype = out, out_dtype + fp32_out = ( + torch.empty_like(requested_out, dtype=torch.float32) + if requested_out is not None + else None + ) + gemm_args = list(args) + gemm_args[0] = gemm_A # A + gemm_args[2] = gemm_B # B + gemm_args[4] = fp32_out # out + gemm_args[5] = None # quantization_params + gemm_args[6] = TE_DType[torch.float32] # out_dtype + gemm_args[7] = None # bias + out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*gemm_args, **kwargs) + out_2d = out.reshape(-1, out.shape[-1]) + + assert rowwise_global_scales.dtype == torch.float32 and out.dtype == torch.float32 + assert rowwise_global_scales.numel() == out_2d.shape[0] + + out_2d.mul_(rowwise_global_scales) + if bias is not None: + out_2d.add_(bias.to(dtype=torch.float32)) + + if requested_out is not None: + requested_out.copy_(out.to(dtype=requested_out.dtype)) + out = requested_out + elif requested_out_dtype is not None and requested_out_dtype != torch.float32: + out = out.to(dtype=requested_out_dtype) if debug_quantizer is not None: out = debug_quantizer.process_gemm_output(out) @@ -229,6 +321,44 @@ def general_grouped_gemm( else: bias_dtype = TE_DType[torch.bfloat16] + if any(_is_nvfp4_row_scaled_tensor(tensor) for tensor in A): + raise NotImplementedError("Row-scaled NVFP4 grouped GEMM does not support row-scaled A.") + if any(_is_nvfp4_row_scaled_tensor(tensor) for tensor in B): + assert D_dtype is None, "Row-scaled NVFP4 grouped GEMM currently does not support D_dtype." + if single_output: + assert ( + m_splits is not None + ), "Row-scaled NVFP4 grouped GEMM requires m_splits with single output." + out_init = out[0] if single_output else None + if single_output: + start_idx = 0 + out_views = [] + for i in range(num_gemms): + size = m_splits[i] + out_views.append(out_init[start_idx : start_idx + size]) + start_idx += size + else: + out_views = out + for i in range(num_gemms): + if out_views[i].numel() == 0: + continue + general_gemm( + A[i], + B[i], + quantization_params=quantization_params[i], + out_dtype=out_views[i].dtype, + out=out_views[i], + gelu=gelu, + accumulate=accumulate, + layout=layout, + bias=bias[i] if use_bias else None, + use_split_accumulator=use_split_accumulator, + grad=grad, + ) + if single_output: + out = out_init + return out, grad_bias, gelu_input + if isinstance(quantization_params[0], DebugQuantizer): assert not gelu, "GELU not supported in debug mode" if single_output: @@ -350,6 +480,13 @@ def general_grouped_gemm_for_grouped_tensor( if is_discrete_in and is_discrete_out: raise ValueError("Both A and out are discrete. This is not supported yet.") + if isinstance(A, GroupedTensorStorage) and A.row_scaled_nvfp4: + raise NotImplementedError("Row-scaled NVFP4 GroupedTensor GEMM is not supported yet.") + if isinstance(B, GroupedTensorStorage) and B.row_scaled_nvfp4: + raise NotImplementedError("Row-scaled NVFP4 GroupedTensor GEMM is not supported yet.") + if isinstance(out, GroupedTensorStorage) and out.row_scaled_nvfp4: + raise NotImplementedError("Row-scaled NVFP4 GroupedTensor GEMM is not supported yet.") + if is_discrete_out: # wgrad case. 
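# Hedged sketch (illustrative shapes; not part of the patch) of the
# single-output fallback used for row-scaled NVFP4 grouped GEMM above:
# each group's GEMM writes into a row-slice view of one preallocated
# output, so empty groups are skipped and no concatenation is needed.
_m_splits = [3, 0, 5]
_out = torch.empty(sum(_m_splits), 4)
_start = 0
for _m in _m_splits:
    _view = _out[_start : _start + _m]    # a view: writes land in _out
    if _view.numel() > 0:
        torch.matmul(torch.randn(_m, 6), torch.randn(6, 4), out=_view)
    _start += _m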
grouped_gemm_impl = tex.te_general_grouped_gemm_for_discrete_out diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 8e3bcdd5b3..8f5b8294e8 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -320,6 +320,8 @@ class NVFP4Quantizer : public Quantizer { // 2D block scaling bool with_2d_quantization; bool stochastic_rounding; + // Whether tensors emitted by this quantizer use row-scaled NVFP4 metadata. + bool row_scaled_nvfp4; int rht_matrix_random_sign_mask_t; at::Tensor rht_matrix; diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index 2df3b66553..cab9fab30a 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -42,8 +42,9 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); - if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) { - // Post-RHT amax is handled within NVFP4 quantizer + if (nvfp4_quantizer_cpp->row_scaled_nvfp4 || + (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax)) { + // Amax is handled within NVFP4 quantizer impl = Impl::UNFUSED; } else { impl = Impl::FUSED_ACTIVATION_AMAX_NVFP4; @@ -154,8 +155,9 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); - if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) { - // Post-RHT amax is handled within NVFP4 quantizer + if (nvfp4_quantizer_cpp->row_scaled_nvfp4 || + (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax)) { + // Amax is handled within NVFP4 quantizer impl = Impl::UNFUSED; } else { impl = Impl::FUSED_ACTIVATION_AMAX_NVFP4; diff --git a/transformer_engine/pytorch/csrc/extensions/bias.cpp b/transformer_engine/pytorch/csrc/extensions/bias.cpp index 0cf2025f1b..4a78dde388 100644 --- a/transformer_engine/pytorch/csrc/extensions/bias.cpp +++ b/transformer_engine/pytorch/csrc/extensions/bias.cpp @@ -152,8 +152,9 @@ std::vector dact_dbias( } else if (detail::IsNVFP4Quantizers(quantizer_py.ptr())) { auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); - if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) { - // Post-RHT amax is handled within NVFP4 quantizer + if (nvfp4_quantizer_cpp->row_scaled_nvfp4 || + (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax)) { + // Amax is handled within NVFP4 quantizer impl = Impl::UNFUSED; } else { impl = Impl::FUSED_DACT_AMAX_NVFP4; diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 50fe4c109e..9e1f381bfe 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -798,7 +798,13 @@ std::tuple, std::vector, bool> bulk_alloc // Quantization parameters const auto rowwise_usage = 
quantizer_cpp_list[0]->rowwise_usage; + const bool row_scaled_nvfp4 = quantizer_cpp_list[0]->row_scaled_nvfp4; const auto columnwise_usage = quantizer_cpp_list[0]->columnwise_usage; + if (row_scaled_nvfp4) { + NVTE_CHECK(rowwise_usage, "Row-scaled NVFP4 bulk allocation requires rowwise usage."); + NVTE_CHECK(!columnwise_usage, + "Row-scaled NVFP4 bulk allocation does not support columnwise usage."); + } const auto scaling_mode = quantizer_cpp_list[0]->get_scaling_mode(); const auto fp4_dtype = quantizer_cpp_list[0]->dtype; const bool with_gemm_swizzled_scales = false; /// TODO (tmoon) Enable based on optimize_for_gemm; @@ -828,6 +834,16 @@ std::tuple, std::vector, bool> bulk_alloc } return fp4_shape; }; + auto flat_first_dim = [](const std::vector &shape) -> size_t { + if (shape.empty()) { + return 1; + } + size_t rows = 1; + for (size_t i = 0; i + 1 < shape.size(); ++i) { + rows *= shape[i]; + } + return rows; + }; // Allocate row-wise data std::vector rowwise_data_list, rowwise_scale_list, amax_rowwise_list; @@ -866,7 +882,11 @@ std::tuple, std::vector, bool> bulk_alloc // Note: Multi-quantize kernel does not require contiguous amaxes. const auto offset = roundup(buffer_size, 16); amax_offsets.push_back(offset); - buffer_size = offset + 4; + size_t amax_size = 4; + if (row_scaled_nvfp4) { + amax_size *= flat_first_dim(rowwise_data_shapes[i]); + } + buffer_size = offset + amax_size; } // Allocate full buffer @@ -879,8 +899,12 @@ std::tuple, std::vector, bool> bulk_alloc data_offsets[i], torch::kUInt8)); rowwise_scale_list.emplace_back( make_torch_view(buffer, rowwise_scale_shapes[i], scale_offsets[i], torch::kUInt8)); + std::vector amax_shape{1}; + if (row_scaled_nvfp4) { + amax_shape = {flat_first_dim(rowwise_data_shapes[i])}; + } amax_rowwise_list.emplace_back( - make_torch_view(buffer, std::vector{1}, amax_offsets[i], torch::kFloat32)); + make_torch_view(buffer, amax_shape, amax_offsets[i], torch::kFloat32)); } } @@ -960,9 +984,10 @@ std::tuple, std::vector, bool> bulk_alloc py::object amax_columnwise = columnwise_usage ? py::cast(amax_columnwise_list[i]) : py::none(); // Construct Python tensor - tensor_py_list.emplace_back(NVFP4TensorClass( - rowwise_data, rowwise_scale, columnwise_data, columnwise_scale, amax_rowwise, - amax_columnwise, fp4_dtype, quantizer_py_list[i], with_gemm_swizzled_scales)); + tensor_py_list.emplace_back(NVFP4TensorClass(rowwise_data, rowwise_scale, columnwise_data, + columnwise_scale, amax_rowwise, amax_columnwise, + fp4_dtype, quantizer_py_list[i], + with_gemm_swizzled_scales, row_scaled_nvfp4)); // Construct C++ tensor // Use a TensorWrapper variable to hold the output of makeTransformerEngineTensor, @@ -979,11 +1004,12 @@ std::tuple, std::vector, bool> bulk_alloc rowwise_usage ? rowwise_scale_shapes[i] : std::vector{0}, columnwise_usage ? 
columnwise_scale_shapes[i] : std::vector{0}, scaling_mode); tensor_wrapper.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); + tensor_wrapper.set_row_scaled_nvfp4(row_scaled_nvfp4); // Set the amax rowwise and amax columnwise if available if (rowwise_usage) { tensor_wrapper.set_amax(amax_rowwise_list[i].data_ptr(), DType::kFloat32, - std::vector{1}); + getTensorShape(amax_rowwise_list[i])); } if (columnwise_usage) { tensor_wrapper.set_columnwise_amax(amax_columnwise_list[i].data_ptr(), DType::kFloat32, @@ -1455,7 +1481,16 @@ std::vector split_quantize(const at::Tensor &tensor, return detail::IsNVFP4Quantizers(quantizer.ptr()); })) { allocation_method = AllocationMethod::BULK_NVFP4; - quantization_method = QuantizationMethod::FUSED_NVFP4; + const bool has_row_scaled_nvfp4 = + std::any_of(quantizer_cpp_list.begin(), quantizer_cpp_list.end(), + [](const std::unique_ptr &quantizer) { + return static_cast(quantizer.get())->row_scaled_nvfp4; + }); + if (has_row_scaled_nvfp4) { + quantization_method = QuantizationMethod::UNFUSED; + } else { + quantization_method = QuantizationMethod::FUSED_NVFP4; + } } } @@ -1492,7 +1527,8 @@ std::vector split_quantize(const at::Tensor &tensor, bool contiguous_data_and_scale = false; std::tie(output_py_list, output_cpp_list, contiguous_data_and_scale) = bulk_allocate_nvfp4_tensors(split_shapes, quantizer_list, nvfp4_quantizers); - if (!input_shape.empty() && input_shape.back() % 128 != 0) { + if (quantization_method == QuantizationMethod::FUSED_NVFP4 && !input_shape.empty() && + input_shape.back() % 128 != 0) { static std::once_flag once_unfused_nvfp4_fallback_warning; std::call_once(once_unfused_nvfp4_fallback_warning, []() { NVTE_WARN( diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index fb4c7aa1c9..4887b59c28 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -120,8 +120,9 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); - if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) { - // Post-RHT amax is handled within NVFP4 quantizer + if (nvfp4_quantizer_cpp->row_scaled_nvfp4 || + (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax)) { + // Amax is handled within NVFP4 quantizer impl = Impl::UNFUSED; } else if (!transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { // TE kernel supports amax output @@ -357,8 +358,9 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w } else if (detail::IsNVFP4Quantizers(quantizer.ptr())) { auto nvfp4_quantizer_cpp = dynamic_cast(quantizer_cpp.get()); NVTE_CHECK(nvfp4_quantizer_cpp != nullptr, "Could not cast to NVFP4 quantizer"); - if (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax) { - // Post-RHT amax is handled within NVFP4 quantizer + if (nvfp4_quantizer_cpp->row_scaled_nvfp4 || + (nvfp4_quantizer_cpp->with_rht && nvfp4_quantizer_cpp->with_post_rht_amax)) { + // Amax is handled within NVFP4 quantizer impl = Impl::UNFUSED; } else if (!transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { // TE kernel supports amax output diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp 
index da91e5c170..8f2de325ae 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1696,6 +1696,7 @@ NVFP4Quantizer::NVFP4Quantizer(const py::handle& quantizer) : Quantizer(quantize this->with_post_rht_amax = quantizer.attr("with_post_rht_amax").cast(); this->with_2d_quantization = quantizer.attr("with_2d_quantization").cast(); this->stochastic_rounding = quantizer.attr("stochastic_rounding").cast(); + this->row_scaled_nvfp4 = quantizer.attr("row_scaled_nvfp4").cast(); // Get amax reduction group if needed for NVFP4 AG const bool with_amax_reduction = quantizer.attr("with_amax_reduction").cast(); @@ -1747,6 +1748,12 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve NVTE_CHECK(flat_last_dim % NVFP4_BLOCK_SIZE == 0, "NVFP4 requires tensor dims that are divisible by ", NVFP4_BLOCK_SIZE, " (got shape=", shape, ")"); + const bool row_scaled_nvfp4 = this->row_scaled_nvfp4; + if (row_scaled_nvfp4) { + NVTE_CHECK(rowwise_usage, "Row-scaled NVFP4 quantization requires rowwise usage."); + NVTE_CHECK(!columnwise_usage, + "Row-scaled NVFP4 quantization does not support columnwise usage."); + } const auto rowwise_scale_inv_shape = get_scale_shape(shape, false); const auto columnwise_scale_inv_shape = get_scale_shape(shape, true); @@ -1760,9 +1767,10 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve rowwise_scale_inv_shape.end()); rowwise_data_tensor = at::empty(convert_shape_for_fp4(shape_int64), bit8_tensor_opts); rowwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, bit8_tensor_opts); + const int64_t amax_rows = row_scaled_nvfp4 ? static_cast(flat_first_dim) : 1; // hadamard amax kernel will zero out pointer with ZeroAmaxKernel // nvte_compute_amax_with_config will zero out the pointer if needed - amax_rowwise = at::empty({1}, bit32_tensor_opts); + amax_rowwise = at::empty({amax_rows}, bit32_tensor_opts); } if (columnwise_usage) { const std::vector scale_inv_shape_int64(columnwise_scale_inv_shape.begin(), @@ -1805,6 +1813,7 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve kwargs["fp4_dtype"] = py::cast(this->dtype); kwargs["quantizer"] = this->quantizer; kwargs["with_gemm_swizzled_scales"] = py::cast(with_gemm_swizzled_scales); + kwargs["row_scaled_nvfp4"] = py::cast(row_scaled_nvfp4); kwargs["fake_dtype"] = GetATenDType(dtype); py::tuple args(0); @@ -1833,6 +1842,7 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve kwargs["fp4_dtype"] = py::cast(this->dtype); kwargs["quantizer"] = this->quantizer; kwargs["with_gemm_swizzled_scales"] = py::cast(with_gemm_swizzled_scales); + kwargs["row_scaled_nvfp4"] = py::cast(row_scaled_nvfp4); py::tuple args(0); PyObject* result = PyObject_Call(reinterpret_cast(NVFP4TensorPythonClass), args.ptr(), kwargs.ptr()); @@ -1850,7 +1860,7 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve out_cpp.set_rowwise_data(rowwise_data_tensor.data_ptr(), DType::kFloat4E2M1, shape); out_cpp.set_rowwise_scale_inv(rowwise_scale_inv_tensor.data_ptr(), DType::kFloat8E4M3, rowwise_scale_inv_shape); - out_cpp.set_amax(amax_rowwise.data_ptr(), DType::kFloat32, std::vector{1}); + out_cpp.set_amax(amax_rowwise.data_ptr(), DType::kFloat32, getTensorShape(amax_rowwise)); } if (columnwise_usage) { // enforce 2D shape to avoid [S, B, H] shape and B and be 1 @@ -1865,6 +1875,7 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve std::vector{1}); } out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); + out_cpp.set_row_scaled_nvfp4(row_scaled_nvfp4); 
@@ -1892,6 +1903,12 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_grouped_tenso
   std::optional<at::Tensor> rowwise_amax;
   std::optional<at::Tensor> columnwise_amax;
   const std::vector<size_t> logical_shape_vec = {logical_first_dim, logical_last_dim};
+  const bool row_scaled_nvfp4 = this->row_scaled_nvfp4;
+  if (row_scaled_nvfp4) {
+    NVTE_CHECK(rowwise_usage, "Row-scaled NVFP4 grouped quantization requires rowwise usage.");
+    NVTE_CHECK(!columnwise_usage,
+               "Row-scaled NVFP4 grouped quantization does not support columnwise usage.");
+  }
 
   const int64_t total_data_elements = total_elements / 2;
 
@@ -1900,7 +1917,9 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_grouped_tenso
     const auto scale_shape = get_scale_shape(logical_shape_vec, false);
     const int64_t total_scale_elements = static_cast<int64_t>(product(scale_shape));
     rowwise_scale_inv = at::empty({total_scale_elements}, uint8_opts);
-    rowwise_amax = at::empty({static_cast<int64_t>(num_tensors)}, float_opts);
+    const int64_t amax_elements = row_scaled_nvfp4 ? static_cast<int64_t>(logical_first_dim)
+                                                   : static_cast<int64_t>(num_tensors);
+    rowwise_amax = at::empty({amax_elements}, float_opts);
   }
 
   if (columnwise_usage) {
@@ -1958,6 +1977,7 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_grouped_tenso
   kwargs["last_dims"] = py::none();
   kwargs["tensor_offsets"] = tensor_offsets.has_value() ? py::cast(*tensor_offsets) : py::none();
   kwargs["with_gemm_swizzled_scales"] = this->optimize_for_gemm;
+  kwargs["row_scaled_nvfp4"] = py::cast(row_scaled_nvfp4);
   PyObject* result = PyObject_Call(GroupedTensorClass.ptr(), args.ptr(), kwargs.ptr());
   if (result == nullptr) {
     PyErr_Print();
@@ -1975,15 +1995,22 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_unquantized_tensor_w
   auto [out_cpp, out_py] = NoneQuantizer(py::none()).create_tensor(shape, dtype);
 
   // Register amax pointer from quantized tensor
-  void* amax_ptr = quantized_tensor.amax();
+  auto rowwise_amax = quantized_tensor.get_amax();
+  auto columnwise_amax = quantized_tensor.get_columnwise_amax();
+
+  void* amax_ptr = rowwise_amax.data_ptr;
+  std::vector<size_t> amax_shape = convertShape(rowwise_amax.shape);
   if (amax_ptr == nullptr) {
-    amax_ptr = quantized_tensor.get_columnwise_amax().data_ptr;
+    amax_ptr = columnwise_amax.data_ptr;
+    amax_shape = convertShape(columnwise_amax.shape);
   }
   NVTE_CHECK(amax_ptr != nullptr, "Could not extract amax pointer from NVFP4 tensor.");
-  out_cpp.set_amax(amax_ptr, DType::kFloat32, std::vector<size_t>{1});
+  out_cpp.set_amax(amax_ptr, DType::kFloat32, amax_shape);
 
   // Zero out amax
-  NVTE_CHECK_CUDA(cudaMemsetAsync(amax_ptr, 0, sizeof(float), at::cuda::getCurrentCUDAStream()));
+  const size_t amax_numel = product(amax_shape);
+  NVTE_CHECK_CUDA(
+      cudaMemsetAsync(amax_ptr, 0, amax_numel * sizeof(float), at::cuda::getCurrentCUDAStream()));
 
   return {std::move(out_cpp), std::move(out_py)};
 }
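Because the amax registered on the unquantized companion tensor can now hold flat_first_dim values, the zeroing memset has to cover the whole buffer; clearing only sizeof(float) bytes would leave stale row amaxes behind. The torch-level equivalent, as a sketch:

    import torch

    def reset_amax(amax: torch.Tensor) -> None:
        # Equivalent of the cudaMemsetAsync above: the entire buffer, not
        # just the first element, must be cleared before max-accumulation.
        amax.zero_()

    reset_amax(torch.empty(128, dtype=torch.float32))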
@@ -2031,6 +2058,13 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::convert_and_update_tensor(
     }
   }
   const size_t flat_last_dim = shape.size() > 0 ? shape.back() : 1;
+  const bool row_scaled_nvfp4 = this->row_scaled_nvfp4;
+  if (row_scaled_nvfp4) {
+    NVTE_CHECK(rowwise_usage, "Row-scaled NVFP4 quantization requires rowwise usage.");
+    NVTE_CHECK(!columnwise_usage,
+               "Row-scaled NVFP4 quantization does not support columnwise usage.");
+  }
+  tensor.attr("_row_scaled_nvfp4") = py::cast(row_scaled_nvfp4);
 
   // Coerce row-wise data
   if (rowwise_usage) {
@@ -2048,11 +2082,12 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::convert_and_update_tensor(
       rowwise_scale_inv = at::empty(scale_inv_shape_int64, opts);
       tensor.attr("_rowwise_scale_inv") = *rowwise_scale_inv;
     }
-    if (!amax_rowwise) {
+    const int64_t amax_rows = row_scaled_nvfp4 ? static_cast<int64_t>(flat_first_dim) : 1;
+    if (!amax_rowwise || amax_rowwise->numel() != amax_rows) {
      const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
       // hadamard amax kernel will zero out pointer with ZeroAmaxKernel
       // nvte_compute_amax_with_config will zero out the pointer if needed
-      amax_rowwise = at::empty({1}, opts);
+      amax_rowwise = at::empty({amax_rows}, opts);
       tensor.attr("_amax_rowwise") = *amax_rowwise;
     }
   } else {  // rowwise_usage == false
@@ -2118,7 +2153,7 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::convert_and_update_tensor(
     out_cpp.set_rowwise_data(rowwise_data->data_ptr(), DType::kFloat4E2M1, shape);
     out_cpp.set_rowwise_scale_inv(rowwise_scale_inv->data_ptr(), DType::kFloat8E4M3,
                                   getTensorShape(*rowwise_scale_inv));
-    out_cpp.set_amax(amax_rowwise->data_ptr(), DType::kFloat32, std::vector<size_t>{1});
+    out_cpp.set_amax(amax_rowwise->data_ptr(), DType::kFloat32, getTensorShape(*amax_rowwise));
   }
   if (columnwise_usage) {
     // enforce 2D shape to avoid [S, B, H] shape and B and be 1
@@ -2133,6 +2168,7 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::convert_and_update_tensor(
                                  std::vector<size_t>{1});
   }
   out_cpp.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales);
+  out_cpp.set_row_scaled_nvfp4(row_scaled_nvfp4);
 
   this->set_quantization_params(&out_cpp);
   return {std::move(out_cpp), std::move(tensor)};
@@ -2241,6 +2277,18 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou
   }
   size_t cols = input.size(input.ndim() - 1);
 
+  const bool row_scaled_nvfp4 = out.get_row_scaled_nvfp4();
+  if (row_scaled_nvfp4) {
+    NVTE_CHECK(!this->with_rht, "Row-scaled NVFP4 quantization does not support RHT.");
+    NVTE_CHECK(!this->with_2d_quantization,
+               "Row-scaled NVFP4 quantization does not support 2D quantization.");
+    NVTE_CHECK(!this->stochastic_rounding,
+               "Row-scaled NVFP4 quantization does not support stochastic rounding.");
+    NVTE_CHECK(!this->with_amax_reduction,
+               "Row-scaled NVFP4 quantization does not support amax reduction.");
+    NVTE_CHECK(cols % 16 == 0, "Row-scaled NVFP4 quantization requires last dim divisible by 16.");
+  }
+
   // Restriction for the RHT cast fusion kernel because we are using MMA hardware for computing RHT
   bool eligible_for_rht_cast_fusion =
       input.dtype() == DType::kBFloat16 && rows % 64 == 0 && cols % 128 == 0;
@@ -2307,7 +2355,7 @@ void NVFP4Quantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& ou
                 "Use with_post_rht_amax=true instead.");
     }
   } else {  // Without RHT
-    if (compute_amax) {
+    if (compute_amax && !row_scaled_nvfp4) {
       // Amax pointers
       auto rowwise_amax_ptr = out.get_amax().data_ptr;
       auto columnwise_amax_ptr = out.get_columnwise_amax().data_ptr;
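With row scaling enabled, quantize_impl deliberately skips the fused per-tensor amax computation; the per-row values are produced by the unfused quantization path itself. The reference reduction is just a row-wise abs-max, sketched here (x is the input flattened to 2D):

    import torch

    def rowwise_amax(x: torch.Tensor) -> torch.Tensor:
        # One FP32 amax per row of the [rows, cols] flattened input.
        return x.abs().amax(dim=-1).to(torch.float32)

    x = torch.randn(4, 32)
    assert rowwise_amax(x).shape == (4,)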
"quantize_with_amax is not supported for row-scaled NVFP4 quantization."); // Update output tensor amaxes with input tensor amax auto input_amax_ptr = input.amax(); auto output_rowwise_amax_ptr = out.get_amax().data_ptr; diff --git a/transformer_engine/pytorch/csrc/type_converters.cpp b/transformer_engine/pytorch/csrc/type_converters.cpp index e13554a98c..37ab0b0535 100644 --- a/transformer_engine/pytorch/csrc/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/type_converters.cpp @@ -134,6 +134,7 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer) const bool rowwise_usage = !(tensor.attr("_rowwise_data").is_none()); const bool columnwise_usage = !(tensor.attr("_columnwise_data").is_none()); const bool with_gemm_swizzled_scales = tensor.attr("_with_gemm_swizzled_scales").cast(); + const bool row_scaled_nvfp4 = tensor.attr("_row_scaled_nvfp4").cast(); NVTE_CHECK(rowwise_usage || columnwise_usage, "No data found for NVFP4 Tensor."); @@ -163,6 +164,7 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer) // Scale layout ret.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales); + ret.set_row_scaled_nvfp4(row_scaled_nvfp4); // Quantizer state quantizer->set_quantization_params(&ret); diff --git a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py index dd01ae05d3..12f8ef8f5b 100644 --- a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py +++ b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py @@ -350,9 +350,17 @@ def __init__( pow_2_scales: bool = False, eps: float = 0.0, quant_tile_shape: Tuple[int, int] = (1, 16), + row_scaled_nvfp4: bool = False, with_rht: bool = False, with_random_sign_mask: bool = True, ): + if row_scaled_nvfp4: + if not rowwise: + raise ValueError("Row-scaled NVFP4 reference quantization requires rowwise usage.") + if columnwise: + raise ValueError( + "Row-scaled NVFP4 reference quantization does not support columnwise usage." 
diff --git a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py
index dd01ae05d3..12f8ef8f5b 100644
--- a/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py
+++ b/transformer_engine/pytorch/custom_recipes/quantization_nvfp4.py
@@ -350,9 +350,17 @@ def __init__(
         pow_2_scales: bool = False,
         eps: float = 0.0,
         quant_tile_shape: Tuple[int, int] = (1, 16),
+        row_scaled_nvfp4: bool = False,
         with_rht: bool = False,
         with_random_sign_mask: bool = True,
     ):
+        if row_scaled_nvfp4:
+            if not rowwise:
+                raise ValueError("Row-scaled NVFP4 reference quantization requires rowwise usage.")
+            if columnwise:
+                raise ValueError(
+                    "Row-scaled NVFP4 reference quantization does not support columnwise usage."
+                )
         super().__init__(rowwise=rowwise, columnwise=columnwise)
 
         self.internal = True
@@ -360,6 +368,7 @@ def __init__(
         self.pow_2_scales = pow_2_scales
         self.eps = eps
         self.quant_tile_shape = quant_tile_shape
+        self.row_scaled_nvfp4 = row_scaled_nvfp4
         self.with_rht = with_rht
         self.with_random_sign_mask = with_random_sign_mask
@@ -447,6 +456,7 @@ def _quantize_blockwise_reference(
         tile_len_y: int,
         *,
         pow_2_scales: bool,
+        row_scaled_nvfp4: bool = False,
         eps: float,
         # pylint: disable=unused-argument
     ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -488,6 +498,9 @@ def _quantize_blockwise_reference(
                 decode_scale.to(torch.float32),
             )
         else:
+            if row_scaled_nvfp4:
+                global_amax = global_amax.to(torch.float32).view(m, 1, 1)
+
             global_encode_scale = torch.div(FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX, global_amax)
             global_encode_scale = torch.min(
                 global_encode_scale,
                 torch.tensor(
@@ -497,8 +510,15 @@ def _quantize_blockwise_reference(
                     dtype=torch.float32,
                 ),
             )
-            if global_encode_scale == torch.tensor(0.0, device=x.device, dtype=torch.float32):
-                global_encode_scale = torch.tensor(1.0, device=x.device, dtype=torch.float32)
+            if global_encode_scale.numel() == 1:
+                if global_encode_scale == torch.tensor(0.0, device=x.device, dtype=torch.float32):
+                    global_encode_scale = torch.tensor(1.0, device=x.device, dtype=torch.float32)
+            else:
+                global_encode_scale = torch.where(
+                    global_encode_scale == 0.0,
+                    torch.ones_like(global_encode_scale),
+                    global_encode_scale,
+                )
 
             global_decode_scale = torch.div(1.0, global_encode_scale)
             global_encode_scale_multiplier = global_encode_scale * torch.reciprocal(FLOAT4_E2M1_MAX)
@@ -609,6 +629,8 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[
                 raise ValueError(
                     f"MXFP4 only supports 1x32 tile shape, got {self.quant_tile_shape}"
                 )
+            if self.row_scaled_nvfp4:
+                raise ValueError("Row-scaled NVFP4 is only supported for NVFP4 (non-pow2) mode.")
             # TODO(etsykunov): Fix bug where global_amax_row and
             # global_amax_col are not defined
             # global_amax = torch.empty(0, device=tensor.device, dtype=torch.float32)
@@ -625,13 +647,22 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[
             if self.with_rht
             else tensor.t().contiguous()
         )
-        # Compute amax for rowwise and columnwise paths separately
-        global_amax_row = torch.max(torch.abs(row_input)).to(torch.float32).view(1)
-        global_amax_col = (
-            torch.max(torch.abs(col_input)).to(torch.float32).view(1)
-            if self.columnwise_usage
-            else global_amax_row
-        )
+        if self.row_scaled_nvfp4:
+            if self.quant_tile_shape != (1, 16):
+                raise ValueError(
+                    "Row-scaled NVFP4 only supports NVFP4 1x16 tile shape, "
+                    f"got {self.quant_tile_shape}"
+                )
+            global_amax_row = torch.max(torch.abs(row_input), dim=1).values.to(torch.float32)
+            global_amax_col = global_amax_row
+        else:
+            # Compute amax for rowwise and columnwise paths separately
+            global_amax_row = torch.max(torch.abs(row_input)).to(torch.float32).view(1)
+            global_amax_col = (
+                torch.max(torch.abs(col_input)).to(torch.float32).view(1)
+                if self.columnwise_usage
+                else global_amax_row
+            )
 
         transpose_scales = False
@@ -648,6 +679,7 @@ def _quantize(self, tensor: torch.Tensor) -> Tuple[
                 self.quant_tile_shape[1],
                 self.quant_tile_shape[0],
                 pow_2_scales=self.pow_2_scales,
+                row_scaled_nvfp4=self.row_scaled_nvfp4,
                 eps=self.eps,
             )
         if transpose_scales:
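In the reference path the per-row global amax is reshaped to (m, 1, 1) so the existing block-scale arithmetic broadcasts one global encode scale per row, and the zero-amax guard becomes an elementwise torch.where. A condensed sketch of that math under the same constants (this collapses several steps of the reference code and is not a drop-in replacement):

    import torch

    FLOAT8_E4M3_MAX = 448.0
    FLOAT4_E2M1_MAX = 6.0

    def per_row_global_encode_scale(x: torch.Tensor) -> torch.Tensor:
        # x: [m, n]. One global encode scale per row, shaped to broadcast
        # against per-block scales of shape [m, n_blocks, 1].
        m = x.shape[0]
        global_amax = x.abs().amax(dim=1).to(torch.float32).view(m, 1, 1)
        encode = (FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX) / global_amax
        # All-zero rows give an infinite scale; fall back to 1.0, mirroring
        # the scalar path's zero-encode-scale guard.
        return torch.where(torch.isfinite(encode), encode, torch.ones_like(encode))

    scales = per_row_global_encode_scale(torch.randn(4, 32))
    assert scales.shape == (4, 1, 1)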
@@ -868,7 +900,11 @@ def qgemm(
             partial_alpha = qresult_x.global_amax_col * qresult_w.global_amax_col
         else:
             partial_alpha = qresult_x.global_amax_row * qresult_w.global_amax_row
-        alpha = torch.div(partial_alpha, factor).squeeze(-1)
+        if partial_alpha.numel() > 1 and partial_alpha.numel() == high_precision_x.shape[0]:
+            partial_alpha = partial_alpha.view(-1, 1)
+        else:
+            partial_alpha = partial_alpha.squeeze(-1)
+        alpha = torch.div(partial_alpha, factor)
 
         M, K = high_precision_x.shape
         N, K_w = high_precision_w.shape
diff --git a/transformer_engine/pytorch/quantization.py b/transformer_engine/pytorch/quantization.py
index 9956fb77ec..e9f009d93d 100644
--- a/transformer_engine/pytorch/quantization.py
+++ b/transformer_engine/pytorch/quantization.py
@@ -1375,6 +1375,7 @@ def _make_quantizer(idx: int) -> NVFP4Quantizer:
                 with_post_rht_amax=qparams.random_hadamard_transform,
                 with_2d_quantization=qparams.fp4_2d_quantization,
                 stochastic_rounding=qparams.stochastic_rounding,
+                row_scaled_nvfp4=self.recipe.row_scaled_activation and idx % 3 != 1,
             )
 
         return [_make_quantizer(idx) for idx in range(self.num_quantizers)]
@@ -1389,6 +1390,7 @@ def _make_quantizer(idx: int) -> NVFP4Quantizer:
                 with_post_rht_amax=self.recipe.fp4_quant_bwd_grad.random_hadamard_transform,
                 with_2d_quantization=self.recipe.fp4_quant_bwd_grad.fp4_2d_quantization,
                 stochastic_rounding=self.recipe.fp4_quant_bwd_grad.stochastic_rounding,
+                row_scaled_nvfp4=False,
             )
             for _ in range(self.num_quantizers)
         ]
diff --git a/transformer_engine/pytorch/tensor/grouped_tensor.py b/transformer_engine/pytorch/tensor/grouped_tensor.py
index ab0c7484fc..f28f972b58 100644
--- a/transformer_engine/pytorch/tensor/grouped_tensor.py
+++ b/transformer_engine/pytorch/tensor/grouped_tensor.py
@@ -92,6 +92,7 @@ def __new__(
         requires_grad: bool = False,
         stride: Optional[List[int]] = None,
         with_gemm_swizzled_scales: bool = False,
+        row_scaled_nvfp4: bool = False,
     ):
         if (
             shapes is not None
@@ -164,6 +165,7 @@ def __new__(
             scale_inv_offsets=scale_inv_offsets,
             columnwise_scale_inv_offsets=columnwise_scale_inv_offsets,
             with_gemm_swizzled_scales=with_gemm_swizzled_scales,
+            row_scaled_nvfp4=row_scaled_nvfp4,
         )
 
         return instance
@@ -195,6 +197,7 @@ def copy_grouped_storage_metadata(dst: GroupedTensor, src: GroupedTensor) -> Non
     dst.logical_shape = src.logical_shape
     dst.quantized_tensors = src.quantized_tensors
     dst._with_gemm_swizzled_scales = src._with_gemm_swizzled_scales
+    dst.row_scaled_nvfp4 = src.row_scaled_nvfp4
 
 
 def make_wrapper_like(src: GroupedTensor, requires_grad: bool) -> GroupedTensor:
     """Create a wrapper of the same type and tensor metadata as src."""
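The idx % 3 != 1 predicate reflects the triple-per-GEMM quantizer layout used by the recipe state; on our reading, slot 1 of each triple is the weight quantizer, which keeps a per-tensor amax even when row_scaled_activation is enabled. The layout itself is defined elsewhere in TE, so treat this as an assumption:

    def is_row_scaled_slot(idx: int, row_scaled_activation: bool) -> bool:
        # Assumed layout: quantizers come in triples per GEMM; slot 1 of
        # each triple is the weight quantizer and stays per-tensor scaled.
        return row_scaled_activation and idx % 3 != 1

    # For one GEMM: slots 0 and 2 (activations) row-scaled, slot 1 (weight) not.
    assert [is_row_scaled_slot(i, True) for i in range(3)] == [True, False, True]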
diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py
index 65678aa347..285a7f030a 100644
--- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py
+++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py
@@ -128,6 +128,9 @@ class NVFP4Quantizer(Quantizer):
     """Stochastic rounding, only applicable for gradients."""
     stochastic_rounding: bool
 
+    """Whether emitted NVFP4 tensors store one FP32 amax per row."""
+    row_scaled_nvfp4: bool
+
     """RHT matrix random sign mask"""
     rht_matrix_random_sign_mask_t: int
     rht_matrix: torch.Tensor
@@ -143,6 +146,7 @@ def __init__(
         with_post_rht_amax: bool = False,
         with_2d_quantization: bool = False,
         stochastic_rounding: bool = False,
+        row_scaled_nvfp4: bool = False,
         with_random_sign_mask: bool = True,
     ) -> None:
         super().__init__(rowwise=rowwise, columnwise=columnwise)
@@ -153,6 +157,7 @@ def __init__(
         self.amax_reduction_group = amax_reduction_group
         self.with_2d_quantization = with_2d_quantization
         self.stochastic_rounding = stochastic_rounding
+        self.row_scaled_nvfp4 = row_scaled_nvfp4
         self.rht_matrix_random_sign_mask_t = get_random_sign_mask_for_rht(
             with_random_sign_mask, torch.cuda.current_device()
         )
@@ -198,6 +203,7 @@ def copy(self) -> NVFP4Quantizer:
             with_post_rht_amax=self.with_post_rht_amax,
             with_2d_quantization=self.with_2d_quantization,
             stochastic_rounding=self.stochastic_rounding,
+            row_scaled_nvfp4=self.row_scaled_nvfp4,
         )
         quantizer.internal = self.internal
         quantizer.optimize_for_gemm = self.optimize_for_gemm
@@ -212,6 +218,8 @@ def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor:
 
     def is_quantizable(self, inp: torch.Tensor) -> bool:
         """Returns whether or not given inp can be quantized"""
+        if self.row_scaled_nvfp4:
+            return False
         if inp.ndim < 2:
             return False
         if inp.shape[-1] % NVFP4_BLOCK_SCALING_SIZE != 0:
@@ -313,6 +321,11 @@ def make_empty(
                 f"Incorrect shape {shape} for NVFP4. Tensor dims must be divisible by"
                 f" {NVFP4_BLOCK_SCALING_SIZE}"
             )
+        if self.row_scaled_nvfp4:
+            if not self.rowwise_usage:
+                raise ValueError("Row-scaled NVFP4 quantization requires rowwise usage.")
+            if self.columnwise_usage:
+                raise ValueError("Row-scaled NVFP4 quantization does not support columnwise usage.")
 
         # Allocate FP4 data
         data = None
@@ -329,8 +342,11 @@ def make_empty(
             scale_inv = torch.empty(
                 scale_shape, dtype=torch.uint8, device=device, pin_memory=pin_memory
             )
-            # Allocate per tensor scale inverse. FP32 format.
-            amax_rowwise = torch.zeros(1, dtype=torch.float32, device=device, pin_memory=pin_memory)
+            # Allocate global amax metadata. Row-scaled NVFP4 stores one value per row.
+            amax_rows = flat_first_dim if self.row_scaled_nvfp4 else 1
+            amax_rowwise = torch.zeros(
+                amax_rows, dtype=torch.float32, device=device, pin_memory=pin_memory
+            )
 
         # Allocate FP8 data transpose if needed
         columnwise_data = None
@@ -371,6 +387,7 @@ def make_empty(
             quantizer=self,
             requires_grad=requires_grad,
             with_gemm_swizzled_scales=False,
+            row_scaled_nvfp4=self.row_scaled_nvfp4,
         )
 
     def calibrate(self, tensor: torch.Tensor) -> None:
@@ -431,6 +448,7 @@ def __new__(
         fp4_dtype: TE_DType,
         quantizer: Quantizer,
         with_gemm_swizzled_scales: bool,
+        row_scaled_nvfp4: bool = False,
         **kwargs,
     ):
         instance = super().__new__(
@@ -445,6 +463,7 @@ def __new__(
             quantizer,
             with_gemm_swizzled_scales,
             *args,
+            row_scaled_nvfp4=row_scaled_nvfp4,
             **kwargs,
         )
         return instance
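Putting the make_empty allocations together for a concrete case: a (256, 1024) tensor quantized row-scaled gets packed FP4 data, one E4M3 scale per 1x16 block, and a 256-element amax vector. Shapes only, assuming the plain non-swizzled scale layout (TE's get_scale_shape may pad or swizzle in practice):

    import torch

    rows, cols = 256, 1024
    data = torch.empty(rows, cols // 2, dtype=torch.uint8)        # 2 FP4 values per byte
    scale_inv = torch.empty(rows, cols // 16, dtype=torch.uint8)  # one E4M3 scale per 1x16 block
    amax_rowwise = torch.zeros(rows, dtype=torch.float32)         # one amax per row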
diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py
index 485b32328b..ac56d334bc 100644
--- a/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py
+++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py
@@ -72,6 +72,7 @@ def _initialize_storage_fields(
         requires_grad: bool = False,
         stride: Optional[List[int]] = None,
         with_gemm_swizzled_scales: bool = False,
+        row_scaled_nvfp4: bool = False,
     ) -> None:
         """
         Initialize a GroupedTensor.
@@ -147,6 +148,7 @@
         # Used as a convenience.
         instance.quantized_tensors = None
         instance._with_gemm_swizzled_scales = with_gemm_swizzled_scales
+        instance.row_scaled_nvfp4 = row_scaled_nvfp4
 
     def __new__(
         cls,
@@ -172,6 +174,7 @@ def __new__(
         requires_grad: bool = False,
         stride: Optional[List[int]] = None,
         with_gemm_swizzled_scales: bool = False,
+        row_scaled_nvfp4: bool = False,
     ):
         instance = object.__new__(cls)
         cls._initialize_storage_fields(
@@ -197,6 +200,7 @@ def __new__(
             requires_grad=requires_grad,
             stride=stride,
             with_gemm_swizzled_scales=with_gemm_swizzled_scales,
+            row_scaled_nvfp4=row_scaled_nvfp4,
         )
 
         return instance
@@ -371,6 +375,7 @@ def clear(self) -> None:
         self.columnwise_scale_inv_offsets = None
         self.tensor_shapes = []
         self.fake_dtype = torch.float32
+        self.row_scaled_nvfp4 = False
 
     def __repr__(self) -> str:
         """String representation of the GroupedTensorStorage."""
@@ -539,6 +544,7 @@ def copy(self) -> "GroupedTensorStorage":
            scale_inv_offsets=self.scale_inv_offsets,
             columnwise_scale_inv_offsets=self.columnwise_scale_inv_offsets,
             with_gemm_swizzled_scales=self._with_gemm_swizzled_scales,
+            row_scaled_nvfp4=self.row_scaled_nvfp4,
         )
 
     @staticmethod
@@ -649,6 +655,7 @@ def make_grouped_tensor(
         scale = None
         scale_inv_offsets = None
         columnwise_scale_inv_offsets = None
+        row_scaled_nvfp4 = False
         if no_quantization:
             assert dtype is not None, "dtype must be provided for unquantized GroupedTensor"
             if rowwise_usage:
@@ -707,6 +714,19 @@ def make_grouped_tensor(
             # Amax buffer for delayed scaling - one per tensor
             amax = torch.empty(num_tensors, dtype=torch.float32, device=device)
         elif quantizer._get_compatible_recipe().nvfp4():
+            row_scaled_nvfp4 = quantizer.row_scaled_nvfp4
+            if row_scaled_nvfp4:
+                if not rowwise_usage:
+                    raise ValueError(
+                        "Row-scaled NVFP4 grouped quantization requires rowwise usage."
+                    )
+                if columnwise_usage:
+                    raise ValueError(
+                        "Row-scaled NVFP4 grouped quantization does not support columnwise usage."
+                    )
+            total_amax_elements = (
+                sum(math.prod(s[:-1]) for s in shape) if row_scaled_nvfp4 else num_tensors
+            )
 
             if rowwise_usage:
                 # Allocate rowwise data buffer (1D flattened, uint8, but FP4 packs 2 values per byte)
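For grouped tensors the rowwise amax buffer grows from one slot per member tensor to the sum of row counts across all members. A worked example of the total_amax_elements arithmetic:

    import math

    shapes = [(128, 256), (64, 256), (32, 256)]
    num_tensors = len(shapes)

    total_row_scaled = sum(math.prod(s[:-1]) for s in shapes)  # 128 + 64 + 32 = 224
    total_per_tensor = num_tensors                             # 3

    assert (total_row_scaled, total_per_tensor) == (224, 3)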
@@ -720,8 +740,7 @@
                 total_scale_elements += math.prod(scale_inv_shape)
                 scale_inv_offsets.append(total_scale_elements)
                 scale_inv = torch.empty(total_scale_elements, dtype=torch.uint8, device=device)
-                # Amax buffer - one per tensor
-                amax = torch.empty(num_tensors, dtype=torch.float32, device=device)
+                amax = torch.empty(total_amax_elements, dtype=torch.float32, device=device)
 
             if columnwise_usage:
                 # Allocate columnwise data buffer (1D flattened, uint8, FP4 packed)
@@ -738,7 +757,6 @@
                 columnwise_scale_inv = torch.empty(
                     total_columnwise_scale_elements, dtype=torch.uint8, device=device
                 )
-                # Columnwise amax buffer - one per tensor
                 columnwise_amax = torch.empty(num_tensors, dtype=torch.float32, device=device)
         elif quantizer._get_compatible_recipe().float8_block_scaling():
             if rowwise_usage:
@@ -824,6 +842,7 @@
             with_gemm_swizzled_scales=(
                 quantizer.optimize_for_gemm if quantizer is not None else False
             ),
+            row_scaled_nvfp4=row_scaled_nvfp4,
         )
         grouped_tensor.quantized_tensors = grouped_tensor.split_into_quantized_tensors()
         return grouped_tensor
@@ -936,6 +955,14 @@ def split_into_quantized_tensors(
                 cum += math.prod(scale_shape)
                 columnwise_scale_inv_offsets.append(cum)
             self.columnwise_scale_inv_offsets = columnwise_scale_inv_offsets
+        nvfp4_rowwise_amax_offsets = None
+        row_scaled_nvfp4 = self.row_scaled_nvfp4
+        if recipe.nvfp4() and row_scaled_nvfp4:
+            cum = 0
+            nvfp4_rowwise_amax_offsets = [0]
+            for i in range(self.num_tensors):
+                cum += math.prod(self.tensor_shapes[i][:-1])
+                nvfp4_rowwise_amax_offsets.append(cum)
 
         for i in range(self.num_tensors):
             quantizer = self.quantizer
@@ -1128,9 +1155,13 @@ def split_into_quantized_tensors(
                     cscale_shape
                 )
 
-            # Extract amax - one per tensor
             if self.amax is not None:
-                amax_rowwise = self.amax[i : i + 1]
+                if nvfp4_rowwise_amax_offsets is not None:
+                    amax_start = nvfp4_rowwise_amax_offsets[i]
+                    amax_end = nvfp4_rowwise_amax_offsets[i + 1]
+                    amax_rowwise = self.amax[amax_start:amax_end]
+                else:
+                    amax_rowwise = self.amax[i : i + 1]
             if self.columnwise_amax is not None:
                 amax_columnwise = self.columnwise_amax[i : i + 1]
@@ -1152,6 +1183,7 @@ def split_into_quantized_tensors(
                     fp4_dtype=quantizer.dtype,
                     quantizer=quantizer,
                     with_gemm_swizzled_scales=quantizer.optimize_for_gemm,
+                    row_scaled_nvfp4=row_scaled_nvfp4,
                 )
 
                 result.append(tensor)
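split_into_quantized_tensors then hands each member its slice of the shared amax buffer via the cumulative row offsets built above. A sketch of the offset construction and slicing:

    import math
    import torch

    shapes = [(128, 256), (64, 256)]
    amax = torch.zeros(sum(math.prod(s[:-1]) for s in shapes))

    offsets = [0]
    for s in shapes:
        offsets.append(offsets[-1] + math.prod(s[:-1]))

    slices = [amax[offsets[i]:offsets[i + 1]] for i in range(len(shapes))]
    assert [s.numel() for s in slices] == [128, 64]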
diff --git a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py
index 70699ad71a..e51acb71e5 100644
--- a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py
+++ b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py
@@ -97,6 +97,8 @@ class NVFP4TensorStorage(QuantizedTensorStorage):
     # Whether scaling factors are in the swizzled format expected by
     # GEMM
     _with_gemm_swizzled_scales: bool
+    # Whether this NVFP4 tensor uses row-scaled amax metadata
+    _row_scaled_nvfp4: bool
 
     def __new__(
         cls,
@@ -111,6 +113,7 @@ def __new__(
         with_gemm_swizzled_scales: bool,
         *args,
         fake_dtype: Optional[torch.dtype] = None,
+        row_scaled_nvfp4: bool = False,
         **kwargs,
     ):
         if cls is NVFP4TensorStorage:
@@ -128,6 +131,7 @@ def __new__(
         instance._amax_rowwise = amax_rowwise
         instance._amax_columnwise = amax_columnwise
         instance._with_gemm_swizzled_scales = with_gemm_swizzled_scales
+        instance._row_scaled_nvfp4 = row_scaled_nvfp4
         return instance
@@ -152,6 +156,8 @@ def copy_from_storage(self, src: QuantizedTensorStorage) -> None:
             raise RuntimeError("FP4 dtype mismatch in copy_from_storage")
         if self._with_gemm_swizzled_scales != src._with_gemm_swizzled_scales:
             raise RuntimeError("Scale layout mismatch in copy_from_storage")
+        if self._row_scaled_nvfp4 != src._row_scaled_nvfp4:
+            raise RuntimeError("Rowwise amax scaling mode mismatch in copy_from_storage")
 
         def _copy_optional(dst: Optional[torch.Tensor], src_tensor: Optional[torch.Tensor]):
             if dst is not None and src_tensor is not None:
@@ -176,6 +182,7 @@ def get_metadata(self) -> Dict[str, Any]:
             "fp4_dtype": self._fp4_dtype,
             "quantizer": self._quantizer,
             "with_gemm_swizzled_scales": self._with_gemm_swizzled_scales,
+            "row_scaled_nvfp4": self._row_scaled_nvfp4,
             "fake_dtype": self._dtype,
         }
@@ -308,6 +315,7 @@ def view(self, shape: torch.Size):
             quantizer=self._quantizer,
             fp4_dtype=self._fp4_dtype,
             with_gemm_swizzled_scales=self._with_gemm_swizzled_scales,
+            row_scaled_nvfp4=self._row_scaled_nvfp4,
             fake_dtype=self._dtype,
         )
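Since both get_metadata and view forward _row_scaled_nvfp4, reshapes and metadata round-trips preserve the amax layout, and copy_from_storage refuses to mix modes. A sketch of the invariant this protects (stub class, not the TE storage API):

    class StorageStub:
        # Minimal stand-in for NVFP4TensorStorage metadata handling.
        def __init__(self, row_scaled_nvfp4: bool = False):
            self._row_scaled_nvfp4 = row_scaled_nvfp4

        def get_metadata(self) -> dict:
            return {"row_scaled_nvfp4": self._row_scaled_nvfp4}

        def view(self) -> "StorageStub":
            # A view must carry the amax layout flag, or a reshaped tensor
            # would be misread as per-tensor scaled.
            return StorageStub(**self.get_metadata())

    assert StorageStub(row_scaled_nvfp4=True).view()._row_scaled_nvfp4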