Changes from all commits (46 commits)
14f77be  Adapt initial implementation and make quantization bitwise exact (zianglih, Apr 26, 2026)
700cbce  Add col (zianglih, Apr 26, 2026)
cfd13bb  Add fp32 (zianglih, Apr 26, 2026)
866d337  Clean up tests (zianglih, Apr 26, 2026)
5a6ea13  Clean up ref (zianglih, Apr 26, 2026)
ee0aafb  Clean up gemm wrapper (zianglih, Apr 27, 2026)
e852804  Clean up test (zianglih, Apr 27, 2026)
9dbb3ad  Clean up (zianglih, Apr 27, 2026)
475de8a  Rename and reformat (zianglih, Apr 27, 2026)
62a1c1e  Avoid partial amax folding in gemm (zianglih, Apr 27, 2026)
44e4e0f  Expand test coverage (zianglih, Apr 27, 2026)
4755f09  Expand more tests (zianglih, Apr 27, 2026)
55286ed  Turn on test for grouped linear sanity (zianglih, Apr 27, 2026)
e4829b8  Rename pertoken to per_token (zianglih, Apr 27, 2026)
dbbdecb  Expand .cu test (zianglih, Apr 27, 2026)
2374a6e  Format after rebase (zianglih, May 2, 2026)
5798285  Fix test after rebase (zianglih, May 2, 2026)
233bb44  Clean up cpp test (zianglih, May 2, 2026)
47c9cde  Extend cpp dequantize test (zianglih, May 2, 2026)
21a19f5  Only pass `per_token_activation` to forward activation quantizer and … (zianglih, May 3, 2026)
75c19d0  Minor fix test (zianglih, May 3, 2026)
a3e8305  Improve accuracy by unfolding weight per-tensor fp32 (zianglih, May 4, 2026)
027cb79  Fold row-wise quantization (zianglih, May 5, 2026)
93a06ad  Drop column wise (zianglih, May 5, 2026)
db1c2a6  Clean up (zianglih, May 5, 2026)
9eb06c7  Clean up (zianglih, May 5, 2026)
21274d8  Clean up column wise (zianglih, May 5, 2026)
4cbb43a  Move shared test helpers (zianglih, May 5, 2026)
d4ab1e7  Minor clean up test (zianglih, May 5, 2026)
363335b  Readability (zianglih, May 5, 2026)
1a4d3b0  Rename (zianglih, May 6, 2026)
66622e8  Further refactor (zianglih, May 6, 2026)
94b05e3  Clean up bias (zianglih, May 6, 2026)
6c10ed2  Clean up cast (zianglih, May 6, 2026)
aa519d1  Avoid silently disable column wise (zianglih, May 6, 2026)
90a97a4  Clean up (zianglih, May 6, 2026)
600b4cd  `is_quantizable` returns false (zianglih, May 6, 2026)
cc9a210  Error out grouped gemm (zianglih, May 6, 2026)
39f96c1  Tighten test (zianglih, May 6, 2026)
4d34527  Rename verbose rowwise_amax_is_row_scaled (zianglih, May 6, 2026)
9676563  Clean up (zianglih, May 6, 2026)
0187d80  Explicitly handle both gemm input and error out (zianglih, May 6, 2026)
ee74019  Minor (zianglih, May 6, 2026)
01a32ef  Nits and lint (zianglih, May 6, 2026)
4e9bef5  Merge branch 'main' into fp4-per-token (zianglih, May 6, 2026)
afc99ad  Minor fix A100 ci (zianglih, May 7, 2026)
6 changes: 6 additions & 0 deletions docs/envvars.rst
@@ -281,6 +281,12 @@ Kernel Configuration
:Default: ``0``
:Description: Emit a warning when falling back from CUTLASS to cuBLAS for grouped GEMM operations.

.. envvar:: NVTE_NVFP4_ROW_SCALED_ACTIVATION

:Type: ``int`` (0 or 1)
:Default: ``0``
:Description: Enable row-scaled NVFP4 tensors for forward activation quantizers in the ``NVFP4BlockScaling`` recipe. When set to ``1`` (or when ``NVFP4BlockScaling(row_scaled_activation=True)`` is used), rowwise ``amax`` metadata is stored as one FP32 value per tensor row instead of a single scalar.
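
A minimal sketch of how a 0/1 integer flag like this is typically consumed; the helper name and the plain ``std::getenv`` lookup are illustrative assumptions, not Transformer Engine's actual parsing code::

    // Illustrative only: reading a 0/1 environment variable such as
    // NVTE_NVFP4_ROW_SCALED_ACTIVATION; unset or "0" means disabled (the default).
    #include <cstdlib>
    #include <cstring>

    bool row_scaled_activation_enabled() {
      const char *val = std::getenv("NVTE_NVFP4_ROW_SCALED_ACTIVATION");
      return val != nullptr && std::strcmp(val, "1") == 0;
    }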

Torch Compilation and Fusion
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

171 changes: 126 additions & 45 deletions tests/cpp/operator/test_cast_nvfp4_transpose.cu
@@ -114,16 +114,14 @@ void quantize_nvfp4_1d(float (*OP)(const float),
block_amax = std::max(block_amax, std::abs(elt));
}

// 2. Compute E4M3 scaling factor
// Compute per-block encoding/decoding scaling factor
const float S_dec_b = block_amax / 6.0f;

// Scale & Store per-block decoding scaling factor
const fp8e4m3 S_dec_b_fp8 = static_cast<fp8e4m3>(S_dec_b * S_enc);
// Compute and store the per-block FP8 decode scale
const float S_dec_b = block_amax * (S_enc * (1.0f / 6.0f));
const fp8e4m3 S_dec_b_fp8 = static_cast<fp8e4m3>(fminf(S_dec_b, Numeric_Traits<float>::maxNorm));
const float S_dec_b_fp32 = static_cast<float>(S_dec_b_fp8);

// Compute "correct" per-block encoding scaling factor
const float S_enc_b_fp8 = S_dec_b_fp32 == 0.f ? 0.f : S_enc / S_dec_b_fp32;
const float S_enc_b_fp8 = S_dec_b_fp32 == 0.f ? 0.f :
fminf(1.0f / (S_dec_b_fp32 * (1.0f / S_enc)), Numeric_Traits<float>::maxNorm);
Contributor Author comment:
We have to change this here to stay aligned with the PyTorch reference.


const size_t scale_idx = i * scales_stride + block_X;
scales[scale_idx] = S_dec_b_fp8;
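
The block above implements the two-level NVFP4 scaling: a global encode scale maps the tensor amax into the FP8 range, and each block stores an E4M3 decode scale from which its encode scale is rebuilt. A standalone sketch, assuming S_enc = (448 * 6) / tensor_amax (E4M3 max times E2M1 max over the global amax) and using a caller-supplied stand-in for the fp8e4m3 round trip (both assumptions are illustrative, not library APIs):

// Standalone sketch of the per-block scale math above.
#include <cfloat>
#include <cmath>

struct BlockScales {
  float dec_scale;  // per-block decode scale after the FP8 round trip
  float enc_scale;  // per-block encode scale applied to the block's elements
};

BlockScales nvfp4_block_scales(float tensor_amax, float block_amax,
                               float (*fp8_round)(float)) {
  const float kFp4Max = 6.0f;    // largest E2M1 magnitude
  const float kFp8Max = 448.0f;  // largest E4M3 magnitude
  const float S_enc = tensor_amax == 0.f ? 0.f : (kFp8Max * kFp4Max) / tensor_amax;
  // Decode scale: block_amax * S_enc / 6, clamped so the FP8 cast cannot overflow.
  const float S_dec_b = fminf(block_amax * (S_enc * (1.0f / kFp4Max)), FLT_MAX);
  const float S_dec_b_fp8 = fp8_round(S_dec_b);
  // Encode scale: S_enc divided by the rounded decode scale, kept finite.
  const float S_enc_b = (S_dec_b_fp8 == 0.f)
      ? 0.f
      : fminf(1.0f / (S_dec_b_fp8 * (1.0f / S_enc)), FLT_MAX);
  return {S_dec_b_fp8, S_enc_b};
}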
@@ -317,11 +315,31 @@ void compute_ref(float (*OP)(const float),
const size_t scales_stride,
const size_t scales_stride_t,
const bool use_fast_math,
const bool use_2d_quantization = false)
const bool use_2d_quantization = false,
std::vector<float> *rowwise_amax = nullptr)
{
std::vector<InputType> input_t = create_transpose(input, rows, cols);

if (use_2d_quantization) {
if (rowwise_amax != nullptr) {
rowwise_amax->resize(rows, 0.0f);
for (size_t row = 0; row < rows; ++row) {
float row_amax = 0.0f;
for (size_t col = 0; col < cols; ++col) {
row_amax = fmaxf(row_amax, fabsf(static_cast<float>(input[row * cols + col])));
}
(*rowwise_amax)[row] = row_amax;
quantize_nvfp4(OP,
input + row * cols,
output + row * (cols / 2),
scales + row * scales_stride,
1,
cols,
scales_stride,
row_amax,
use_fast_math,
use_2d_quantization);
}
} else if (use_2d_quantization) {
// Step 1: Compute mathematical 8×8 scaling factors
std::vector<std::vector<fp8e4m3>> math_scales;
compute_2d_mathematical_scales(OP, input, rows, cols, global_amax, math_scales, use_fast_math);
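
For the row-scaled branch added above, every row is treated as its own 1 x cols tensor whose global amax is that row's amax. A tiny self-contained example with made-up values; it mirrors the amax loop only, not the actual quantize_nvfp4 call:

// Illustrative only: per-row amax for a 2x4 input, matching the loop in the
// rowwise_amax branch above; each row would then be quantized as a 1x4 tensor
// with this value used as its global amax.
#include <cmath>
#include <cstdio>

int main() {
  const float input[2][4] = {{-1.5f, 0.25f, 3.0f, -0.5f},
                             { 2.0f, -7.0f, 0.0f,  4.0f}};
  for (int row = 0; row < 2; ++row) {
    float row_amax = 0.0f;
    for (int col = 0; col < 4; ++col) {
      row_amax = fmaxf(row_amax, fabsf(input[row][col]));
    }
    std::printf("row %d amax = %.2f\n", row, row_amax);  // 3.00, then 7.00
  }
  return 0;
}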
@@ -504,13 +522,12 @@ void print_detailed_tensor_comparison(const std::string& name,

void compareResults_nvfp4(const Tensor &test,
const void *ref, const void *ref_t, const int rows, const int cols,
double atol = 1e-5, double rtol = 1e-8, bool if_on_gpus = true, bool dump_data = false) {
double atol = 1e-5, double rtol = 1e-8, bool if_on_gpus = true,
bool dump_data = false, bool compare_columnwise = true) {
if (if_on_gpus) test.to_cpu();

const fp4e2m1 *test_data = test.rowwise_cpu_dptr<fp4e2m1>();
const fp4e2m1 *test_data_t = test.columnwise_cpu_dptr<fp4e2m1>();
const fp4e2m1 *ref_data = reinterpret_cast<const fp4e2m1*>(ref);
const fp4e2m1 *ref_data_t = reinterpret_cast<const fp4e2m1*>(ref_t);

// Print detailed element-by-element comparison
// print_detailed_tensor_comparison("output", test_data, ref_data, rows, cols);
@@ -519,17 +536,33 @@
// Optionally dump tensor data to files for detailed analysis
if (dump_data) {
dump_nvfp4_tensor_data("output", test_data, ref_data, rows, cols);
dump_nvfp4_tensor_data("output_t", test_data_t, ref_data_t, cols, rows);
}

compare_nvfp4_tensors("output", test_data, ref_data, rows, cols, atol, rtol);
compare_nvfp4_tensors("output_t", test_data_t, ref_data_t, cols, rows, atol, rtol);
if (compare_columnwise) {
const fp4e2m1 *test_data_t = test.columnwise_cpu_dptr<fp4e2m1>();
const fp4e2m1 *ref_data_t = reinterpret_cast<const fp4e2m1*>(ref_t);
if (dump_data) {
dump_nvfp4_tensor_data("output_t", test_data_t, ref_data_t, cols, rows);
}
compare_nvfp4_tensors("output_t", test_data_t, ref_data_t, cols, rows, atol, rtol);
}
}

void compare_rowwise_amax(const Tensor &output, const std::vector<float> &ref_amax) {
const std::vector<float> test_amax_data = output.tensor_amax_values();
ASSERT_EQ(test_amax_data.size(), ref_amax.size());
for (size_t row = 0; row < ref_amax.size(); ++row) {
ASSERT_EQ(test_amax_data[row], ref_amax[row])
<< "Row-scaled amax mismatch at row " << row;
}
}

template <typename InputType>
void performTest(float (*OP)(const float),
const std::vector<size_t>& shape,
const bool use_fast_math) {
const bool use_fast_math,
const bool row_scaled_nvfp4 = false) {
using namespace test;

DType itype = TypeInfo<InputType>::dtype;
@@ -556,7 +589,7 @@ void performTest(float (*OP)(const float),
const size_t scales_stride_t = blocks_X_t;

Tensor input("input", shape, itype);
Tensor output("output", shape, otype, true, true, NVTE_NVFP4_1D_SCALING);
Tensor output("output", shape, otype, true, !row_scaled_nvfp4, NVTE_NVFP4_1D_SCALING);

std::unique_ptr<fp4e2m1x2[]> ref_output = std::make_unique<fp4e2m1x2[]>(rows * (cols / 2));
std::unique_ptr<fp4e2m1x2[]> ref_output_t = std::make_unique<fp4e2m1x2[]>(cols * (rows / 2));
@@ -567,26 +600,44 @@

// Golden value of amax chosen to make the 2nd-stage scaling mantissa zero and avoid rounding issues
const float amax = 448.0f * 6.0f * 8.0f;

// Set 2nd stage NVFP4 scaling factor
output.set_tensor_amax(amax);
output.set_tensor_amax_columnwise(amax);

std::vector<float> ref_rowwise_amax;
bool use_2d_quantization = false;
if (row_scaled_nvfp4) {
output.set_tensor_amax_shape({rows});
output.set_row_scaled_nvfp4(true);
compute_ref<InputType>(OP,
input.rowwise_cpu_dptr<InputType>(),
ref_output.get(),
ref_output_t.get(),
ref_scales.get(),
ref_scales_t.get(),
0.0f,
rows,
cols,
scales_stride,
scales_stride_t,
use_fast_math,
use_2d_quantization,
&ref_rowwise_amax);
} else {
// Set 2nd stage NVFP4 scaling factor
output.set_tensor_amax(amax);
output.set_tensor_amax_columnwise(amax);
compute_ref<InputType>(OP,
input.rowwise_cpu_dptr<InputType>(),
ref_output.get(),
ref_output_t.get(),
ref_scales.get(),
ref_scales_t.get(),
amax,
rows,
cols,
scales_stride,
scales_stride_t,
use_fast_math,
use_2d_quantization);
}

compute_ref<InputType>(OP,
input.rowwise_cpu_dptr<InputType>(),
ref_output.get(),
ref_output_t.get(),
ref_scales.get(),
ref_scales_t.get(),
amax,
rows,
cols,
scales_stride,
scales_stride_t,
use_fast_math,
use_2d_quantization);
// Initialize stochastic rounding
Tensor rng_state("rng_state", std::vector<size_t>{2}, DType::kInt64);
rng_state.rowwise_cpu_dptr<int64_t>()[0] = 123; // rng_seed
@@ -629,23 +680,25 @@
const double rtol = 1.0E-6;

// Set dump_data=true to enable dumping tensor data to files for analysis
compareResults_nvfp4(output, ref_output.get(), ref_output_t.get(), rows, cols, atol, rtol, true, false);

const fp8e4m3* kernel_scales = output.rowwise_cpu_scale_inv_ptr<fp8e4m3>();
const fp8e4m3* ref_scales_ptr = ref_scales.get();
const fp8e4m3* kernel_scales_t = output.columnwise_cpu_scale_inv_ptr<fp8e4m3>();
const fp8e4m3* ref_scales_t_ptr = ref_scales_t.get();
compareResults_nvfp4(output, ref_output.get(), ref_output_t.get(), rows, cols, atol, rtol, true,
false, !row_scaled_nvfp4);

size_t scale_mismatches_num = 0;
compare_scaling_factors<fp8e4m3>("scales", output.rowwise_cpu_scale_inv_ptr<fp8e4m3>(),
ref_scales.get(),
unpadded_blocks_Y, unpadded_blocks_X, scales_stride,
scale_mismatches_num);

compare_scaling_factors<fp8e4m3>("scales_t", output.columnwise_cpu_scale_inv_ptr<fp8e4m3>(),
ref_scales_t.get(),
unpadded_blocks_Y_t, unpadded_blocks_X_t, scales_stride_t,
scale_mismatches_num);
if (!row_scaled_nvfp4) {
compare_scaling_factors<fp8e4m3>("scales_t", output.columnwise_cpu_scale_inv_ptr<fp8e4m3>(),
ref_scales_t.get(),
unpadded_blocks_Y_t, unpadded_blocks_X_t, scales_stride_t,
scale_mismatches_num);
}

if (row_scaled_nvfp4) {
compare_rowwise_amax(output, ref_rowwise_amax);
}
}

std::vector<std::vector<size_t>> tensor_dims = {
@@ -678,6 +731,7 @@ class FusedCastTransposeNVFP4TestSuite : public ::testing::TestWithParam
<std::tuple<ActivationType,
std::vector<size_t>,
transformer_engine::DType,
bool,
bool>> {};

TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) {
@@ -693,6 +747,7 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) {
const auto tensor_dims = std::get<1>(GetParam());
const DType input_type = std::get<2>(GetParam());
const bool use_fast_math = std::get<3>(GetParam());
const bool row_scaled_nvfp4 = std::get<4>(GetParam());

// Skip tests if the input tensor is 1D
if (tensor_dims.size() < 2) {
@@ -710,7 +765,7 @@ TEST_P(FusedCastTransposeNVFP4TestSuite, TestFusedCastTransposeNVFP4) {
}

TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType,
performTest<InputType>(OP, tensor_dims, use_fast_math);
performTest<InputType>(OP, tensor_dims, use_fast_math, row_scaled_nvfp4);
);
}

Expand All @@ -733,6 +788,7 @@ INSTANTIATE_TEST_SUITE_P(
::testing::ValuesIn(Activation_types),
::testing::ValuesIn(tensor_dims),
::testing::Values(DType::kBFloat16),
::testing::Values(false),
::testing::Values(false)),
[](const testing::TestParamInfo<FusedCastTransposeNVFP4TestSuite::ParamType>& info) {
std::string name = to_string(std::get<0>(info.param));
@@ -746,3 +802,28 @@
}
return name;
});

INSTANTIATE_TEST_SUITE_P(
OperatorTestRowScaled,
FusedCastTransposeNVFP4TestSuite,
::testing::Combine(
::testing::Values(ActivationType::Identity),
::testing::Values(tensor_dims[4], tensor_dims[9], tensor_dims[12]),
::testing::Values(DType::kBFloat16, DType::kFloat32),
::testing::Values(false),
::testing::Values(true)),
[](const testing::TestParamInfo<FusedCastTransposeNVFP4TestSuite::ParamType>& info) {
std::string name = to_string(std::get<0>(info.param));
const auto& shape = std::get<1>(info.param);
for (const auto& s: shape) {
name += "X" + std::to_string(s);
}
name += "X" + test::typeName(std::get<2>(info.param));
if (std::get<3>(info.param)) {
name += "X_FAST_SCALING";
}
if (std::get<4>(info.param)) {
name += "XROW_SCALED";
}
return name;
});