diff --git a/tests/cpp/operator/test_act.cu b/tests/cpp/operator/test_act.cu
index ca5ccdc4ce..6edc6bd63b 100644
--- a/tests/cpp/operator/test_act.cu
+++ b/tests/cpp/operator/test_act.cu
@@ -124,6 +124,7 @@ void performTest(const size_t N, const size_t H) {
   fillUniform(&input);
   fillUniform(&ograd);
   setRandomScale(&output);
+  const float ref_scale = isFp8Type(otype) ? output.scale() : 1.0f;
 
   std::unique_ptr ref_output = std::make_unique(N*H);
   std::unique_ptr ref_igrad = std::make_unique(N*H);
@@ -132,7 +133,7 @@ void performTest(const size_t N, const size_t H) {
   float ref_amax;
   compute_ref_act_cast(input.rowwise_cpu_dptr(), ref_output.get(),
-                       output.scale(), &ref_amax, N, H);
+                       ref_scale, &ref_amax, N, H);
 
   cudaDeviceSynchronize();
   auto err = cudaGetLastError();
@@ -179,6 +180,7 @@ void performTestGLU(const size_t N, const size_t H) {
   fillUniform(&input);
   fillUniform(&ograd);
   setRandomScale(&output);
+  const float ref_scale = isFp8Type(otype) ? output.scale() : 1.0f;
 
   std::unique_ptr ref_output = std::make_unique(N * H);
   std::unique_ptr ref_igrad = std::make_unique(2 * N * H);
@@ -187,7 +189,7 @@ void performTestGLU(const size_t N, const size_t H) {
   float ref_amax;
   compute_ref_glu_act_cast(input.rowwise_cpu_dptr(), ref_output.get(),
-                           output.scale(), &ref_amax, N, H);
+                           ref_scale, &ref_amax, N, H);
 
   cudaDeviceSynchronize();
   auto err = cudaGetLastError();
@@ -197,8 +199,8 @@ void performTestGLU(const size_t N, const size_t H) {
     auto [atol, rtol] = getTolerances(DType::kFloat32);
     compareResults("amax", output.amax(), ref_amax, atol, rtol);
     if (output.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
-      const float ref_scale = 1.f / output.scale();
-      compareResults("scale_inv", *output.rowwise_cpu_scale_inv_ptr<float>(), ref_scale, atol, rtol);
+      const float ref_scale_inv = 1.f / ref_scale;
+      compareResults("scale_inv", *output.rowwise_cpu_scale_inv_ptr<float>(), ref_scale_inv, atol, rtol);
     }
   }
   auto [atol, rtol] = getTolerances(otype);
diff --git a/tests/cpp/operator/test_cast.cu b/tests/cpp/operator/test_cast.cu
index 35d9dd2efd..e8f48feef8 100644
--- a/tests/cpp/operator/test_cast.cu
+++ b/tests/cpp/operator/test_cast.cu
@@ -53,13 +53,14 @@ void performTest(const std::vector<size_t>& shape) {
   fillUniform(&input);
   setRandomScale(&output_c);
+  const float ref_scale = isFp8Type(otype) ? output_c.scale() : 1.0f;
 
   nvte_quantize(input.data(), output_c.data(), 0);
 
   float ref_amax;
   compute_ref(input.rowwise_cpu_dptr(), ref_output_c.get(),
-              full_size, &ref_amax, output_c.scale());
+              full_size, &ref_amax, ref_scale);
 
   cudaDeviceSynchronize();
   auto err = cudaGetLastError();
@@ -67,7 +68,7 @@ void performTest(const std::vector<size_t>& shape) {
   if (isFp8Type(otype)) {
     auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
     compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
-    float ref_scale_inv = 1.f / output_c.scale();
+    float ref_scale_inv = 1.f / ref_scale;
     compareResults("scale_inv", output_c.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
   }
   auto [atol, rtol] = getTolerances(otype);
diff --git a/tests/cpp/operator/test_cast_current_scaling.cu b/tests/cpp/operator/test_cast_current_scaling.cu
index 4dd6cd2d58..7cca0d72e0 100644
--- a/tests/cpp/operator/test_cast_current_scaling.cu
+++ b/tests/cpp/operator/test_cast_current_scaling.cu
@@ -123,6 +123,7 @@ void performTest(const std::vector<size_t>& shape) {
   nvte_compute_amax(input.data(), output_c.data(), 0);
   QuantizationConfigWrapper config;
   nvte_compute_scale_from_amax(output_c.data(), config, 0);
+  // Avoid atomic amax updates in the CUDA cast kernels when using current per-tensor scaling
   amax_to_check = output_c.amax();
   output_c.set_tensor_amax_nullptr();
@@ -130,7 +131,7 @@
   nvte_quantize(input.data(), output_c.data(), 0);
 
   float ref_amax;
-  float ref_scale;
+  float ref_scale = 1.0;
   float ref_scale_inv;
   if (is_out_fp8){
     compute_amax_scale_ref(input.rowwise_cpu_dptr(),
@@ -138,13 +139,13 @@
   }
   compute_ref(input.rowwise_cpu_dptr(), ref_output_c.get(),
-              full_size, nullptr, is_out_fp8 ? output_c.scale() : 1.0f );
+              full_size, nullptr, ref_scale);
 
   cudaDeviceSynchronize();
   auto err = cudaGetLastError();
   ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
 
-  if (isFp8Type(otype)) {
+  if (is_out_fp8) {
     auto [atol_fp32, rtol_fp32] = getTolerances(DType::kFloat32);
     compareResults("amax", amax_to_check, ref_amax, 0.0f, rtol_fp32);
     compareResults("scale", output_c.scale(), ref_scale, 0.0f, rtol_fp32);
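Note on the test_cast_current_scaling.cu change above: with current per-tensor scaling the scale is derived on-device from the input amax before the cast, so the host-side reference now defaults `ref_scale` to 1.0 and only fills it for FP8 outputs. A condensed sketch of the flow these hunks exercise, using only calls that already appear in the diff (stream 0 throughout):

    // Current per-tensor scaling: amax -> scale -> quantize
    nvte_compute_amax(input.data(), output_c.data(), 0);        // amax(input) -> output's amax slot
    QuantizationConfigWrapper config;
    nvte_compute_scale_from_amax(output_c.data(), config, 0);   // scale derived from amax
    output_c.set_tensor_amax_nullptr();   // detach amax so the cast kernel skips its atomic update
    nvte_quantize(input.data(), output_c.data(), 0);            // cast with the freshly computed scale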
diff --git a/tests/cpp/operator/test_cast_dbias.cu b/tests/cpp/operator/test_cast_dbias.cu
index 18f07153c6..b7b5db48c3 100644
--- a/tests/cpp/operator/test_cast_dbias.cu
+++ b/tests/cpp/operator/test_cast_dbias.cu
@@ -74,13 +74,14 @@ void performTest(const std::vector<size_t>& shape) {
   fillUniform(&input);
   setRandomScale(&output_c);
+  const float ref_scale = isFp8Type(otype) ? output_c.scale() : 1.0f;
 
   std::unique_ptr ref_output_c = std::make_unique(N*H);
   std::unique_ptr ref_output_dbias = std::make_unique(H);
 
   CType ref_amax;
   compute_ref_cast_dbias(input.rowwise_cpu_dptr(),
-                         output_c.scale(),
+                         ref_scale,
                          ref_output_c.get(),
                          &ref_amax,
                          ref_output_dbias.get(),
@@ -109,7 +110,7 @@ void performTest(const std::vector<size_t>& shape) {
   if (isFp8Type(otype)) {
     auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
     compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
-    float ref_scale_inv = 1.f / output_c.scale();
+    float ref_scale_inv = 1.f / ref_scale;
     compareResults("scale_inv", output_c.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
   }
   auto [atol, rtol] = getTolerances(otype);
diff --git a/tests/cpp/operator/test_cast_dbias_dgelu.cu b/tests/cpp/operator/test_cast_dbias_dgelu.cu
index 8213e5665a..d8b8a20e6f 100644
--- a/tests/cpp/operator/test_cast_dbias_dgelu.cu
+++ b/tests/cpp/operator/test_cast_dbias_dgelu.cu
@@ -84,6 +84,7 @@ void performTest(const std::vector<size_t>& shape) {
   fillUniform(&input);
   fillUniform(&grad);
   setRandomScale(&output_c);
+  const float ref_scale = isFp8Type(otype) ? output_c.scale() : 1.0f;
 
   std::unique_ptr ref_output_c = std::make_unique(N*H);
   std::unique_ptr ref_output_dbias = std::make_unique(H);
@@ -91,7 +92,7 @@ void performTest(const std::vector<size_t>& shape) {
   CType ref_amax;
   compute_ref_cast_dbias_dgelu(input.rowwise_cpu_dptr(),
                                grad.rowwise_cpu_dptr(),
-                               output_c.scale(),
+                               ref_scale,
                                ref_output_c.get(),
                                &ref_amax,
                                ref_output_dbias.get(),
@@ -123,7 +124,7 @@ void performTest(const std::vector<size_t>& shape) {
   if (isFp8Type(otype)) {
     auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
     compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
-    float ref_scale_inv = 1.f / output_c.scale();
+    float ref_scale_inv = 1.f / ref_scale;
     compareResults("scale_inv", output_c.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
   }
diff --git a/tests/cpp/operator/test_cast_gated_swiglu.cu b/tests/cpp/operator/test_cast_gated_swiglu.cu
index 298b978f2a..5298cc7577 100644
--- a/tests/cpp/operator/test_cast_gated_swiglu.cu
+++ b/tests/cpp/operator/test_cast_gated_swiglu.cu
@@ -79,6 +79,7 @@ void performTest(const std::vector<size_t>& shape) {
   fillUniform(&grad);
   fillUniform(&input);
   setRandomScale(&output_c);
+  const float ref_scale = isFp8Type(otype) ? output_c.scale() : 1.0f;
 
   std::unique_ptr ref_output_c = std::make_unique(input_size);
@@ -91,7 +92,7 @@ void performTest(const std::vector<size_t>& shape) {
   float ref_amax;
   compute_ref_cast_dgated_swiglu(grad.rowwise_cpu_dptr(),
                                  input.rowwise_cpu_dptr(),
-                                 output_c.scale(),
+                                 ref_scale,
                                  ref_output_c.get(),
                                  &ref_amax,
                                  rows,
@@ -100,7 +101,7 @@ void performTest(const std::vector<size_t>& shape) {
   if (isFp8Type(otype)) {
     auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
     compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
-    float ref_scale_inv = 1.f / output_c.scale();
+    float ref_scale_inv = 1.f / ref_scale;
     compareResults("scale_inv", output_c.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
   }
diff --git a/tests/cpp/operator/test_cast_nvfp4_transpose.cu b/tests/cpp/operator/test_cast_nvfp4_transpose.cu
index 15d7c695c9..4ab06abcde 100644
--- a/tests/cpp/operator/test_cast_nvfp4_transpose.cu
+++ b/tests/cpp/operator/test_cast_nvfp4_transpose.cu
@@ -502,7 +502,7 @@ void print_detailed_tensor_comparison(const std::string& name,
   printf("==================================\n");
 }
 
-void compareResults_nvfp4(const Tensor &test,
+void compareResults_nvfp4(Tensor &test,
                           const void *ref, const void *ref_t,
                           const int rows, const int cols,
                           double atol = 1e-5, double rtol = 1e-8,
                           bool if_on_gpus = true, bool dump_data = false) {
   if (if_on_gpus) test.to_cpu();
diff --git a/tests/cpp/operator/test_cast_transpose.cu b/tests/cpp/operator/test_cast_transpose.cu
index 44c78e4a09..9a5dc959da 100644
--- a/tests/cpp/operator/test_cast_transpose.cu
+++ b/tests/cpp/operator/test_cast_transpose.cu
@@ -55,13 +55,13 @@ void performTest(const size_t N, const size_t H) {
   fillUniform(&input);
   setRandomScale(&output);
+  const float ref_scale = isFp8Type(otype) ? output.scale() : 1.0f;
 
   nvte_quantize(input.data(), output.data(), 0);
 
   float ref_amax;
   compute_ref(input.rowwise_cpu_dptr(), ref_output_c.get(),
-              ref_output_t.get(), N, H, &ref_amax,
-              output.scale());
+              ref_output_t.get(), N, H, &ref_amax, ref_scale);
 
   cudaDeviceSynchronize();
   auto err = cudaGetLastError();
@@ -69,7 +69,7 @@ void performTest(const size_t N, const size_t H) {
   if (isFp8Type(otype)) {
     auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
     compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax);
-    float ref_scale_inv = 1.f / output.scale();
+    float ref_scale_inv = 1.f / ref_scale;
     compareResults("scale_inv", output.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
   }
   auto [atol, rtol] = getTolerances(otype);
diff --git a/tests/cpp/operator/test_cast_transpose_dbias.cu b/tests/cpp/operator/test_cast_transpose_dbias.cu
index 5b06b28327..f9303d34f5 100644
--- a/tests/cpp/operator/test_cast_transpose_dbias.cu
+++ b/tests/cpp/operator/test_cast_transpose_dbias.cu
@@ -73,6 +73,7 @@ void performTest(const size_t N, const size_t H) {
   fillUniform(&input);
   setRandomScale(&output);
+  const float ref_scale = isFp8Type(otype) ? output.scale() : 1.0f;
 
   std::unique_ptr ref_output_c = std::make_unique(N*H);
   std::unique_ptr ref_output_t = std::make_unique(N*H);
@@ -80,7 +81,7 @@ void performTest(const size_t N, const size_t H) {
   CType ref_amax;
   compute_ref_cast_transpose_dbias(input.rowwise_cpu_dptr(),
-                                   output.scale(),
+                                   ref_scale,
                                    ref_output_c.get(),
                                    ref_output_t.get(),
                                    &ref_amax,
@@ -111,7 +112,7 @@ void performTest(const size_t N, const size_t H) {
   if (isFp8Type(otype)) {
     auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
     compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax);
-    float ref_scale_inv = 1.f / output.scale();
+    float ref_scale_inv = 1.f / ref_scale;
     compareResults("scale_inv", output.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
   }
   auto [atol, rtol] = getTolerances(otype);
diff --git a/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu b/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu
index 9a4a2fa080..31eafff80f 100644
--- a/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu
+++ b/tests/cpp/operator/test_cast_transpose_dbias_dgelu.cu
@@ -86,6 +86,7 @@ void performTest(const size_t N, const size_t H) {
   fillUniform(&input);
   fillUniform(&gelu_input);
   setRandomScale(&output);
+  const float ref_scale = isFp8Type(otype) ? output.scale() : 1.0f;
 
   std::unique_ptr ref_output_c = std::make_unique(N*H);
   std::unique_ptr ref_output_t = std::make_unique(N*H);
@@ -94,7 +95,7 @@ void performTest(const size_t N, const size_t H) {
   CType ref_amax;
   compute_ref_cast_transpose_dbias_dgelu(input.rowwise_cpu_dptr(),
                                          gelu_input.rowwise_cpu_dptr(),
-                                         output.scale(),
+                                         ref_scale,
                                          ref_output_c.get(),
                                          ref_output_t.get(),
                                          &ref_amax,
@@ -127,7 +128,7 @@ void performTest(const size_t N, const size_t H) {
   if (isFp8Type(otype)) {
     auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
     compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax);
-    float ref_scale_inv = 1.f / output.scale();
+    float ref_scale_inv = 1.f / ref_scale;
     compareResults("scale_inv", output.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
   }
diff --git a/tests/cpp/operator/test_cast_transpose_dgeglu.cu b/tests/cpp/operator/test_cast_transpose_dgeglu.cu
index a87c0c5a42..15ecd3ab66 100644
--- a/tests/cpp/operator/test_cast_transpose_dgeglu.cu
+++ b/tests/cpp/operator/test_cast_transpose_dgeglu.cu
@@ -81,6 +81,7 @@ void performTest(const size_t N, const size_t H) {
   fillUniform(&grad);
   fillUniform(&input);
   setRandomScale(&output);
+  const float ref_scale = isFp8Type(otype) ? output.scale() : 1.0f;
 
   std::unique_ptr ref_output_c = std::make_unique(N * H * 2);
   std::unique_ptr ref_output_t = std::make_unique(N * H * 2);
@@ -89,7 +90,7 @@ void performTest(const size_t N, const size_t H) {
   CType ref_amax;
   compute_ref_cast_transpose_dgated_gelu(grad.rowwise_cpu_dptr(),
                                          input.rowwise_cpu_dptr(),
-                                         output.scale(), ref_output_c.get(), ref_output_t.get(),
+                                         ref_scale, ref_output_c.get(), ref_output_t.get(),
                                          &ref_amax, N, H);
 
   cudaDeviceSynchronize();
   auto err = cudaGetLastError();
@@ -99,7 +100,7 @@ void performTest(const size_t N, const size_t H) {
   if (isFp8Type(otype)) {
     auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
     compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax);
-    float ref_scale_inv = 1.f / output.scale();
+    float ref_scale_inv = 1.f / ref_scale;
     compareResults("scale_inv", output.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
   }
diff --git a/tests/cpp/operator/test_dequantize_nvfp4.cu b/tests/cpp/operator/test_dequantize_nvfp4.cu
index 96e85cb5ed..20efc943b6 100644
--- a/tests/cpp/operator/test_dequantize_nvfp4.cu
+++ b/tests/cpp/operator/test_dequantize_nvfp4.cu
@@ -75,7 +75,7 @@ void compute_ref_dequantize_nvfp4(const uint8_t *packed_data,
 }
 
 template <typename T>
-float compute_amax(const test::Tensor &t, size_t rows, size_t cols) {
+float compute_amax(test::Tensor &t, size_t rows, size_t cols) {
   t.to_cpu();
   const auto *data = t.rowwise_cpu_dptr<T>();
   float amax = 0.0f;
diff --git a/tests/cpp/operator/test_multi_cast_transpose.cu b/tests/cpp/operator/test_multi_cast_transpose.cu
index 2bb35c4b89..0271c9dc6b 100644
--- a/tests/cpp/operator/test_multi_cast_transpose.cu
+++ b/tests/cpp/operator/test_multi_cast_transpose.cu
@@ -97,7 +97,7 @@ void performTest() {
     std::copy(input.rowwise_cpu_dptr(),
               input.rowwise_cpu_dptr() + height * width,
               ref_input_list.back().begin());
-    ref_scale_list[tensor_id] = output.scale();
+    ref_scale_list[tensor_id] = isFp8Type(otype) ? output.scale() : 1.0f;
     ref_height_list[tensor_id] = height;
     ref_width_list[tensor_id] = width;
   }
@@ -138,7 +138,7 @@ void performTest() {
                      atol_amax, rtol_amax);
       compareResults("scale_inv",
                      output_list[tensor_id].rowwise_scale_inv(),
-                     1.f / output_list[tensor_id].scale(),
+                     1.f / ref_scale_list[tensor_id],
                      atol_amax, rtol_amax);
     }
     auto [atol, rtol] = getTolerances(otype);
diff --git a/tests/cpp/operator/test_normalization.cu b/tests/cpp/operator/test_normalization.cu
index f737005e26..ea6692dba4 100644
--- a/tests/cpp/operator/test_normalization.cu
+++ b/tests/cpp/operator/test_normalization.cu
@@ -208,7 +208,7 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma,
   auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
   if (isFp8Type(otype)) {
     compareResults("amax", z.amax(), ref_amax, atol_amax, rtol_amax);
-    float ref_scale_inv = 1.f / z.scale();
+    float ref_scale_inv = 1.f / ref_scale;
     compareResults("scale_inv", z.rowwise_scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
   }
diff --git a/tests/cpp/operator/test_qdq.cu b/tests/cpp/operator/test_qdq.cu
index 4e364fffa4..034280aa9a 100644
--- a/tests/cpp/operator/test_qdq.cu
+++ b/tests/cpp/operator/test_qdq.cu
@@ -65,12 +65,13 @@ void performTestQ(const size_t N) {
   fillUniform(&input);
   setRandomScale(&output);
+  const float ref_scale = output.scale();
 
   nvte_quantize(input.data(), output.data(), 0);
 
   float ref_amax;
   compute_ref_q(input.rowwise_cpu_dptr(), ref_output.get(),
-                N, &ref_amax, output.scale());
+                N, &ref_amax, ref_scale);
 
   cudaDeviceSynchronize();
   auto err = cudaGetLastError();
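The same guarded reference scale recurs in every operator test above: non-FP8 outputs leave the tensor's scale field unset, so the CPU reference must fall back to an identity scale. The pattern in isolation (a sketch; `isFp8Type` and `Tensor::scale()` as declared in test_common.h):

    // Reference scale for the CPU implementation: FP8 outputs use the tensor's
    // quantization scale, every other dtype is compared unscaled.
    const float ref_scale = isFp8Type(otype) ? output.scale() : 1.0f;
    const float ref_scale_inv = 1.f / ref_scale;  // expected rowwise scale_inv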
diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu
index c756b83810..2becbc7302 100644
--- a/tests/cpp/test_common.cu
+++ b/tests/cpp/test_common.cu
@@ -8,12 +8,13 @@
 #include "test_common.h"
 
 #include
+#include
+#include
+#include
+#include
 #include
 #include
 #include
-#include
-#include
-#include
 #include
 #include
@@ -193,33 +194,6 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
     return {ret_rowwise, ret_colwise};
   }
 
-  if (scaling_mode == NVTE_MXFP8_1D_SCALING) {
-    std::vector<size_t> shape_vec;
-    for (size_t i = 0; i < shape.ndim; ++i) {
-      shape_vec.push_back(shape.data[i]);
-    }
-    size_t first_dim = first_dimension(shape_vec);
-    size_t last_dim = last_dimension(shape_vec);
-
-    scale_inv_meta ret_rowwise, ret_colwise;
-
-    const size_t block_size_X_rowwise = 32;
-    size_t scale_dim_Y_rowwise = DIVUP_TO_MULTIPLE(first_dim, scale_tensor_alignment_Y_rowwise);
-    size_t scale_dim_X_rowwise = DIVUP_TO_MULTIPLE(DIVUP(last_dim, block_size_X_rowwise), scale_tensor_alignment_X_rowwise);
-    ret_rowwise.shape = {scale_dim_Y_rowwise, scale_dim_X_rowwise};
-
-    const size_t block_size_Y_colwise = 32;
-    size_t scale_dim_Y_colwise = DIVUP_TO_MULTIPLE(DIVUP(first_dim, block_size_Y_colwise), scale_tensor_alignment_Y_colwise);
-    size_t scale_dim_X_colwise = DIVUP_TO_MULTIPLE(last_dim, scale_tensor_alignment_X_colwise);
-    ret_colwise.shape = {scale_dim_Y_colwise, scale_dim_X_colwise};
-
-    ret_rowwise.type = DType::kFloat8E8M0;
-    ret_colwise.type = DType::kFloat8E8M0;
-    ret_rowwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0);
-    ret_colwise.type_size_bits = typeToNumBits(DType::kFloat8E8M0);
-
-    return {ret_rowwise, ret_colwise};
-  }
   if (scaling_mode == NVTE_BLOCK_SCALING_2D) {
     std::vector<size_t> shape_vec;
     for (size_t i = 0; i < shape.ndim; ++i) {
@@ -276,6 +250,36 @@ std::pair<scale_inv_meta, scale_inv_meta> get_scales(const NVTEShape& shape,
   NVTE_ERROR("Invalid scaling mode!");
 }
 
+Tensor::Buffer::Buffer(size_t size, DType dtype)
+    : size_{size}, dtype_{dtype}, bytes_{size * typeToNumBits(dtype) / 8} {
+  if (bytes_ > 0) {
+    cpu_buffer_.reset(new unsigned char[bytes_]);
+    std::memset(cpu_buffer_.get(), 0, bytes_);
+    unsigned char *gpu_buffer = nullptr;
+    NVTE_CHECK_CUDA(cudaMalloc(&gpu_buffer, bytes_));
+    gpu_buffer_.reset(gpu_buffer);
+    NVTE_CHECK_CUDA(cudaMemset(gpu_buffer_.get(), 0, bytes_));
+  }
+}
+
+void Tensor::Buffer::to_cpu() {
+  if (bytes_ > 0) {
+    NVTE_CHECK_CUDA(cudaMemcpy(cpu_buffer_.get(), gpu_buffer_.get(), bytes_, cudaMemcpyDeviceToHost));
+  }
+}
+
+void Tensor::Buffer::from_cpu() {
+  if (bytes_ > 0) {
+    NVTE_CHECK_CUDA(cudaMemcpy(gpu_buffer_.get(), cpu_buffer_.get(), bytes_, cudaMemcpyHostToDevice));
+  }
+}
+
+void Tensor::Buffer::GPUDeleter::operator()(void *ptr) {
+  if (ptr != nullptr) {
+    cudaFree(ptr);
+  }
+}
+
 Tensor::Tensor(const std::string& name,
                const NVTEShape &shape, const DType type,
                const bool rowwise, const bool columnwise,
@@ -303,31 +307,13 @@ Tensor::Tensor(const std::string& name,
     flattened_shape = convertShape(flattened_shape_vec);
   }
 
-  // Allocate and initialize data
-  void *dptr_rowwise = nullptr, *dptr_columnwise = nullptr;
-  const size_t total_size = bytes(shape, type);
-  if (total_size != 0) {
-    if (rowwise) {
-      cudaMalloc((void**)&dptr_rowwise, total_size);  // NOLINT(*)
-      cudaMemset(dptr_rowwise, 0, total_size);
-      cpu_data_rowwise_ = std::make_unique<unsigned char[]>(total_size);
-      std::fill_n(cpu_data_rowwise_.get(), total_size, 0);
-    }
-    if (columnwise) {
-      cudaMalloc((void**)&dptr_columnwise, total_size);  // NOLINT(*)
-      cudaMemset(dptr_columnwise, 0, total_size);
-      cpu_data_columnwise_ = std::make_unique<unsigned char[]>(total_size);
-      std::fill_n(cpu_data_columnwise_.get(), total_size, 0);
-    }
-  }
-
-  // Set tensor row-wise data
+  // Allocate row-wise data
   if (rowwise) {
-    const DType rowwise_type = (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat4E2M1 : type;
-    tensor_.set_rowwise_data(dptr_rowwise, rowwise_type, shape);
+    data_rowwise_ = Tensor::Buffer(product(shape), type);
+    tensor_.set_rowwise_data(data_rowwise_.gpu_buffer(), type, shape);
   }
 
-  // Set tensor column-wise data
+  // Allocate column-wise data
   if (columnwise) {
     // Determine shape of column-wise data
    std::vector<size_t> columnwise_shape_vec;
@@ -358,257 +344,208 @@ Tensor::Tensor(const std::string& name,
     const auto columnwise_shape = nvte_make_shape(columnwise_shape_vec.data(),
                                                   columnwise_shape_vec.size());
 
-    // Set column-wise data buffer
-    const DType colwise_type = (scaling_mode == NVTE_NVFP4_1D_SCALING) ? DType::kFloat4E2M1 : type;
-    tensor_.set_columnwise_data(dptr_columnwise, colwise_type, columnwise_shape);
+    // Allocate buffer
+    data_columnwise_ = Tensor::Buffer(product(columnwise_shape), type);
+
+    // Configure TE tensor
+    tensor_.set_columnwise_data(data_columnwise_.gpu_buffer(), type, columnwise_shape);
   }
 
-  // Configure scales, amaxes, and other tensor buffers
-  float *amax = nullptr;
-  float *amax_columnwise = nullptr;
-  float *scale = nullptr;
-  float *rowwise_scale_inv = nullptr;
-  float *columnwise_scale_inv = nullptr;
-  if (isFp8Type(type) || isFp4Type(type)) {
-    if (scaling_mode == NVTE_DELAYED_TENSOR_SCALING) {
-      cudaMalloc((void**)&amax, sizeof(float));  // NOLINT(*)
-      cudaMemset(amax, 0, sizeof(float));
-      cudaMalloc((void**)&scale, sizeof(float));  // NOLINT(*)
-      cudaMemset(scale, 0, sizeof(float));
-      amax_cpu_data_ = std::make_shared<float>(0);
-      scale_cpu_data_ = std::make_shared<float>(0);
-      tensor_.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
-      tensor_.set_scale(scale, DType::kFloat32, std::vector<size_t>{1});
-      cudaMalloc((void**)&rowwise_scale_inv, sizeof(float));  // NOLINT(*)
+  // Allocate recipe-specific buffers
+  switch (scaling_mode) {
+  case NVTE_DELAYED_TENSOR_SCALING:
+    if (isFp8Type(type)) {
+      amax_rowwise_ = Tensor::Buffer(1, DType::kFloat32);
+      scale_ = Tensor::Buffer(1, DType::kFloat32);
+      scale_inv_rowwise_ = Tensor::Buffer(1, DType::kFloat32);
+      tensor_.set_amax(amax_rowwise_.gpu_buffer(), DType::kFloat32, std::vector<size_t>{1});
+      tensor_.set_scale(scale_.gpu_buffer(), DType::kFloat32, std::vector<size_t>{1});
       if (rowwise) {
-        tensor_.set_rowwise_scale_inv(rowwise_scale_inv, DType::kFloat32,
-                                      std::vector<size_t>{1});
-        rowwise_scale_inv_cpu_data_ = std::make_unique<unsigned char[]>(sizeof(float));
-        std::fill_n(rowwise_scale_inv_cpu_data_.get(), sizeof(float), 0);
+        tensor_.set_rowwise_scale_inv(scale_inv_rowwise_.gpu_buffer(), DType::kFloat32, std::vector<size_t>{1});
       }
       if (columnwise) {
-        tensor_.set_columnwise_scale_inv(rowwise_scale_inv, DType::kFloat32,
-                                         std::vector<size_t>{1});
-        columnwise_scale_inv_cpu_data_ = std::make_unique<unsigned char[]>(sizeof(float));
-        std::fill_n(columnwise_scale_inv_cpu_data_.get(), sizeof(float), 0);
-      }
-    } else {
-      if (scaling_mode == NVTE_NVFP4_1D_SCALING) {
-        // Used for NVFP4 second stage scaling
-        amax_cpu_data_ = std::make_shared<float>(0);
-        amax_cpu_data_columnwise_ = std::make_shared<float>(0);
-        cudaMalloc((void**)&amax, sizeof(float));  // NOLINT(*)
-        cudaMalloc((void**)&amax_columnwise, sizeof(float));  // NOLINT(*)
-        cudaMemset(amax, 0, sizeof(float));
-        cudaMemset(amax_columnwise, 0, sizeof(float));
-        tensor_.set_amax(amax, DType::kFloat32, std::vector<size_t>{1});
-        tensor_.set_columnwise_amax(amax_columnwise, DType::kFloat32, std::vector<size_t>{1});
+        tensor_.set_columnwise_scale_inv(scale_inv_rowwise_.gpu_buffer(), DType::kFloat32, std::vector<size_t>{1});
       }
+    }
+    break;
+  case NVTE_MXFP8_1D_SCALING:
+  case NVTE_BLOCK_SCALING_1D:
+  case NVTE_BLOCK_SCALING_2D:
+  case NVTE_NVFP4_1D_SCALING:
+    {
+      // Block scaling factors
       auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(flattened_shape, tensor_.scaling_mode());
-      auto rowwise_scale_size = rowwise_scale_meta.bytes();
-      auto columnwise_scale_size = colwise_scale_meta.bytes();
-      auto scale_shape = rowwise_scale_meta.shape;
-      auto columnwise_scale_shape = colwise_scale_meta.shape;
       if (rowwise) {
-        cudaMalloc((void **)&rowwise_scale_inv, rowwise_scale_size);  // NOLINT(*)
-        cudaMemset(rowwise_scale_inv, 0, rowwise_scale_size);
-        rowwise_scale_inv_cpu_data_ = std::make_unique<unsigned char[]>(rowwise_scale_size);
-        std::fill_n(rowwise_scale_inv_cpu_data_.get(), rowwise_scale_size, 0);
-        auto scale_dtype = rowwise_scale_meta.type;
-        tensor_.set_rowwise_scale_inv(rowwise_scale_inv, scale_dtype, scale_shape);
+        const auto scale_shape = rowwise_scale_meta.shape;
+        const auto scale_dtype = rowwise_scale_meta.type;
+        scale_inv_rowwise_ = Tensor::Buffer(product(scale_shape), scale_dtype);
+        tensor_.set_rowwise_scale_inv(scale_inv_rowwise_.gpu_buffer(), scale_dtype, scale_shape);
       }
       if (columnwise) {
-        cudaMalloc((void**)&columnwise_scale_inv, columnwise_scale_size);  // NOLINT(*)
-        cudaMemset(columnwise_scale_inv, 0, columnwise_scale_size);
-        columnwise_scale_inv_cpu_data_ = std::make_unique<unsigned char[]>(columnwise_scale_size);
-        std::fill_n(columnwise_scale_inv_cpu_data_.get(), columnwise_scale_size, 0);
-        auto scale_dtype = colwise_scale_meta.type;
-        tensor_.set_columnwise_scale_inv(columnwise_scale_inv, scale_dtype, columnwise_scale_shape);
+        const auto scale_shape = colwise_scale_meta.shape;
+        const auto scale_dtype = colwise_scale_meta.type;
+        scale_inv_columnwise_ = Tensor::Buffer(product(scale_shape), scale_dtype);
+        tensor_.set_columnwise_scale_inv(scale_inv_columnwise_.gpu_buffer(), scale_dtype, scale_shape);
       }
-    }
-  }
-}
 
-void Tensor::to_cpu() const {
-  const NVTEShape s = tensor_.shape();
-  const size_t size = bytes(s, tensor_.dtype());
-  if (rowwise_) {
-    cudaMemcpy(cpu_data_rowwise_.get(),
-               tensor_.get_rowwise_data().data_ptr,
-               size,
-               cudaMemcpyDeviceToHost);
-  }
-  if (columnwise_) {
-    const DType colwise_type = tensor_.dtype();
-    const size_t colwise_size = bytes(s, colwise_type);
-    cudaMemcpy(cpu_data_columnwise_.get(),
-               tensor_.get_columnwise_data().data_ptr,
-               colwise_size,
-               cudaMemcpyDeviceToHost);
-  }
-  if (isFp8Type(dtype()) || isFp4Type(dtype())) {
-    if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
-      if (tensor_.amax() != nullptr){
-        cudaMemcpy(amax_cpu_data_.get(),
-                   tensor_.amax(),
-                   sizeof(float),
-                   cudaMemcpyDeviceToHost);
-      }
-      cudaMemcpy(scale_cpu_data_.get(),
-                 tensor_.scale(),
-                 sizeof(float),
-                 cudaMemcpyDeviceToHost);
-    } else if (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING) {
-      if (rowwise_ && (tensor_.amax() != nullptr)){
-        cudaMemcpy(amax_cpu_data_.get(),
-                   tensor_.amax(),
-                   sizeof(float),
-                   cudaMemcpyDeviceToHost);
-      }
-      if (columnwise_ && (tensor_.get_columnwise_amax().data_ptr != nullptr)){
-        cudaMemcpy(amax_cpu_data_columnwise_.get(),
-                   tensor_.get_columnwise_amax().data_ptr,
-                   sizeof(float),
-                   cudaMemcpyDeviceToHost);
+      // NVFP4 uses amax for tensor scaling
+      if (scaling_mode == NVTE_NVFP4_1D_SCALING) {
+        amax_rowwise_ = Tensor::Buffer(1, DType::kFloat32);
+        amax_columnwise_ = Tensor::Buffer(1, DType::kFloat32);
+        tensor_.set_amax(amax_rowwise_.gpu_buffer(), DType::kFloat32, std::vector<size_t>{1});
+        tensor_.set_columnwise_amax(amax_columnwise_.gpu_buffer(), DType::kFloat32, std::vector<size_t>{1});
       }
     }
-    auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode());
-    if (rowwise_) {
-      auto scale_size = rowwise_scale_meta.bytes();
-      cudaMemcpy(rowwise_scale_inv_cpu_data_.get(),
-                 tensor_.get_rowwise_scale_inv().data_ptr,
-                 scale_size,
-                 cudaMemcpyDeviceToHost);
-    }
-    if (columnwise_) {
-      auto scale_size = colwise_scale_meta.bytes();
-      cudaMemcpy(columnwise_scale_inv_cpu_data_.get(),
-                 tensor_.get_columnwise_scale_inv().data_ptr,
-                 scale_size,
-                 cudaMemcpyDeviceToHost);
-    }
+    break;
+  default:
+    NVTE_ERROR("Unsupported tensor format (", static_cast<int>(scaling_mode), ")");
   }
 }
 
-void Tensor::from_cpu() const {
-  const NVTEShape s = tensor_.shape();
-  const size_t size = bytes(s, tensor_.dtype());
-  if (rowwise_) {
-    cudaMemcpy(tensor_.get_rowwise_data().data_ptr, cpu_data_rowwise_.get(), size,
-               cudaMemcpyHostToDevice);
-  }
-  if (columnwise_) {
-    cudaMemcpy(tensor_.get_columnwise_data().data_ptr, cpu_data_columnwise_.get(), size,
-               cudaMemcpyHostToDevice);
-  }
-  if (isFp8Type(dtype()) || isFp4Type(dtype())) {
-    if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
-      if (tensor_.amax() != nullptr){
-        cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice);
-      }
-      cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice);
-    } else if (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING) {
-      if (rowwise_ && (tensor_.amax() != nullptr)) {
-        cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float), cudaMemcpyHostToDevice);
-      }
-      if (columnwise_ && (tensor_.get_columnwise_amax().data_ptr != nullptr)) {
-        cudaMemcpy(tensor_.get_columnwise_amax().data_ptr, amax_cpu_data_columnwise_.get(),
-                   sizeof(float), cudaMemcpyHostToDevice);
-      }
-    }
-    auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(s, tensor_.scaling_mode());
-    if (rowwise_) {
-      auto scale_size = rowwise_scale_meta.bytes();
-      cudaMemcpy(tensor_.get_rowwise_scale_inv().data_ptr,
-                 rowwise_scale_inv_cpu_data_.get(), scale_size,
-                 cudaMemcpyHostToDevice);
-    }
-    if (columnwise_) {
-      auto scale_size = colwise_scale_meta.bytes();
-      cudaMemcpy(tensor_.get_columnwise_scale_inv().data_ptr,
-                 columnwise_scale_inv_cpu_data_.get(), scale_size,
-                 cudaMemcpyHostToDevice);
-    }
-  }
+void Tensor::to_cpu() {
+  data_rowwise_.to_cpu();
+  data_columnwise_.to_cpu();
+  scale_inv_rowwise_.to_cpu();
+  scale_inv_columnwise_.to_cpu();
+  amax_rowwise_.to_cpu();
+  amax_columnwise_.to_cpu();
+  scale_.to_cpu();
+}
+
+void Tensor::from_cpu() {
+  data_rowwise_.from_cpu();
+  data_columnwise_.from_cpu();
+  scale_inv_rowwise_.from_cpu();
+  scale_inv_columnwise_.from_cpu();
+  amax_rowwise_.from_cpu();
+  amax_columnwise_.from_cpu();
+  scale_.from_cpu();
+}
+
+void Tensor::set_amax(float amax) {
+  NVTE_CHECK(amax_rowwise_.size() == 1);
+  NVTE_CHECK(amax_rowwise_.dtype() == DType::kFloat32);
+  *amax_rowwise_.cpu_buffer<float>() = amax;
+  amax_rowwise_.from_cpu();
 }
 
 void Tensor::set_scale(float scale) {
-  if (isFp8Type(dtype()) || isFp4Type(dtype())) {
-    NVTE_CHECK(scale_cpu_data_);
-    if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
-      *scale_cpu_data_ = scale;
-      from_cpu();
-    }
-  }
+  NVTE_CHECK(scale_.size() == 1);
+  NVTE_CHECK(scale_.dtype() == DType::kFloat32);
+  *scale_.cpu_buffer<float>() = scale;
+  scale_.from_cpu();
 }
 
 void Tensor::set_scale_inv(float scale_inv) {
-  if (isFp8Type(dtype()) || isFp4Type(dtype())) {
-    if (rowwise_) {
-      NVTE_CHECK(rowwise_scale_inv_cpu_data_);
-    }
-    if (columnwise_) {
-      NVTE_CHECK(columnwise_scale_inv_cpu_data_);
-    }
-    auto [rowwise_scale_meta, colwise_scale_meta] = get_scales(tensor_.shape(), tensor_.scaling_mode());
-    if (rowwise_) {
-      auto num_scales = product(rowwise_scale_meta.shape);
-      if (num_scales == 1) {
-        rowwise_cpu_scale_inv_ptr<float>()[0] = scale_inv;
-      } else {
-        std::uniform_int_distribution<uint8_t> dis(0, 127);
-        auto *scale_inv_ptr = rowwise_cpu_scale_inv_ptr<uint8_t>();
-        for (size_t i = 0; i < num_scales; i++) {
-          scale_inv_ptr[i] = dis(gen_);
-        }
-      }
-    }
-    if (columnwise_) {
-      auto num_scales = product(colwise_scale_meta.shape);
-      if (num_scales == 1) {
-        columnwise_cpu_scale_inv_ptr<float>()[0] = scale_inv;
-      } else {
-        std::uniform_int_distribution<uint8_t> dis(0, 127);
-        auto *scale_inv_ptr = columnwise_cpu_scale_inv_ptr<uint8_t>();
-        for (size_t i = 0; i < num_scales; i++) {
-          scale_inv_ptr[i] = dis(gen_);
-        }
-      }
-    }
-    from_cpu();
-  }
+  NVTE_CHECK(scale_inv_rowwise_.size() == 1);
+  NVTE_CHECK(scale_inv_rowwise_.dtype() == DType::kFloat32);
+  *scale_inv_rowwise_.cpu_buffer<float>() = scale_inv;
+  scale_inv_rowwise_.from_cpu();
+}
+
+void Tensor::set_tensor_amax(float amax) {
+  set_amax(amax);
+}
+
+void Tensor::set_tensor_amax_columnwise(float amax) {
+  NVTE_CHECK(amax_columnwise_.size() == 1);
+  NVTE_CHECK(amax_columnwise_.dtype() == DType::kFloat32);
+  *amax_columnwise_.cpu_buffer<float>() = amax;
+  amax_columnwise_.from_cpu();
+}
+
+void Tensor::fill_uniform_rowwise_scale_inv() {
+  if (scale_inv_rowwise_.size() == 0) {
+    return;
+  }
+
+  // Generate random scales on CPU
+  const auto numel = scale_inv_rowwise_.size();
+  const auto dtype = scale_inv_rowwise_.dtype();
+  switch (dtype) {
+  case DType::kFloat32:
+    {
+      auto *cpu_data = scale_inv_rowwise_.cpu_buffer<float>();
+      std::uniform_real_distribution<float> dis(-2.0, 1.0);
+      for (size_t i = 0; i < numel; ++i) {
+        cpu_data[i] = dis(gen_);
+      }
+    }
+    break;
+  case DType::kFloat8E4M3:
+  case DType::kFloat8E8M0:
+  case DType::kByte:
+    {
+      auto *cpu_data = reinterpret_cast<uint8_t *>(scale_inv_rowwise_.cpu_buffer());
+      std::uniform_int_distribution<uint8_t> dis(0, 127);
+      for (size_t i = 0; i < numel; ++i) {
+        cpu_data[i] = dis(gen_);
+      }
+    }
+    break;
+  default:
+    NVTE_ERROR("Unsupported rowwise scale-inv dtype (",
+               static_cast<int>(dtype), ").");
+  }
+
+  // Update GPU tensor
+  scale_inv_rowwise_.from_cpu();
+}
+
+void Tensor::fill_uniform_columnwise_scale_inv() {
+  if (scale_inv_columnwise_.size() == 0) {
+    return;
+  }
+
+  // Generate random scales on CPU
+  const auto numel = scale_inv_columnwise_.size();
+  const auto dtype = scale_inv_columnwise_.dtype();
+  switch (dtype) {
+  case DType::kFloat32:
+    {
+      auto *cpu_data = scale_inv_columnwise_.cpu_buffer<float>();
+      std::uniform_real_distribution<float> dis(-2.0, 1.0);
+      for (size_t i = 0; i < numel; ++i) {
+        cpu_data[i] = dis(gen_);
+      }
+    }
+    break;
+  case DType::kFloat8E4M3:
+  case DType::kFloat8E8M0:
+  case DType::kByte:
+    {
+      auto *cpu_data = reinterpret_cast<uint8_t *>(scale_inv_columnwise_.cpu_buffer());
+      std::uniform_int_distribution<uint8_t> dis(0, 127);
+      for (size_t i = 0; i < numel; ++i) {
+        cpu_data[i] = dis(gen_);
+      }
+    }
+    break;
+  default:
+    NVTE_ERROR("Unsupported columnwise scale-inv dtype (",
+               static_cast<int>(dtype), ").");
+  }
+
+  // Update GPU tensor
+  scale_inv_columnwise_.from_cpu();
 }
 
-void Tensor::shareFP8Meta(const Tensor &other) {
-  if ((isFp8Type(dtype()) && isFp8Type(other.dtype()))
-      || isFp4Type(dtype()) && isFp4Type(other.dtype())) {
-    auto new_tensor = TensorWrapper(other.tensor_.scaling_mode());
-    auto my_rowwise_data = tensor_.get_rowwise_data();
-    new_tensor.set_rowwise_data(my_rowwise_data.data_ptr, static_cast<DType>(my_rowwise_data.dtype),
-                                my_rowwise_data.shape);
-    auto my_columnwise_data = tensor_.get_columnwise_data();
-    new_tensor.set_columnwise_data(my_columnwise_data.data_ptr,
-                                   static_cast<DType>(my_columnwise_data.dtype),
-                                   my_columnwise_data.shape);
-    auto other_amax = other.tensor_.get_amax();
-    new_tensor.set_amax(other_amax.data_ptr, static_cast<DType>(other_amax.dtype),
-                        other_amax.shape);
-    auto other_scale = other.tensor_.get_scale();
-    new_tensor.set_scale(other_scale.data_ptr, static_cast<DType>(other_scale.dtype),
-                         other_scale.shape);
-    auto other_row_scale_inv = other.tensor_.get_rowwise_scale_inv();
-    new_tensor.set_rowwise_scale_inv(other_row_scale_inv.data_ptr,
-                                     static_cast<DType>(other_row_scale_inv.dtype),
-                                     other_row_scale_inv.shape);
-    auto other_col_scale_inv = other.tensor_.get_columnwise_scale_inv();
-    new_tensor.set_columnwise_scale_inv(other_col_scale_inv.data_ptr,
-                                        static_cast<DType>(other_col_scale_inv.dtype),
-                                        other_col_scale_inv.shape);
-    tensor_ = std::move(new_tensor);
-    to_cpu();
-  }
+void Tensor::fill_uniform_scale() {
+  if (scale_.size() == 0) {
+    return;
+  }
+
+  // Generate random scales on CPU
+  auto *cpu_data = scale_.cpu_buffer<float>();
+  const auto numel = scale_.size();
+  NVTE_CHECK(scale_.dtype() == DType::kFloat32);
+  std::uniform_real_distribution<float> dis(-2.0, 1.0);
+  for (size_t i = 0; i < numel; ++i) {
+    cpu_data[i] = dis(gen_);
+  }
+
+  // Update GPU tensor
+  scale_.from_cpu();
 }
 
 using std::to_string;
@@ -636,7 +573,7 @@ std::vector<size_t> unravel(const size_t i, const NVTEShape &shape) {
   return ret;
 }
 
-void compareResults_sequential(const std::string &name, const Tensor &test,
+void compareResults_sequential(const std::string &name, Tensor &test,
                                const void *ref, const bool rowwise,
                                double atol, double rtol, bool if_on_gpus,
                                const size_t tolerable_mismatches_limit) {
@@ -726,7 +663,7 @@ static size_t getFirstMismatchIdx(const DType data_type, const T* test_data, con
   return first_mismatch_idx;
 }
 
-void compareResults_parallel(const std::string &name, const Tensor &test, const void *ref,
+void compareResults_parallel(const std::string &name, Tensor &test, const void *ref,
                              const bool rowwise, double atol, double rtol, bool if_on_gpus,
                              const size_t tolerable_mismatches_limit) {
   if (if_on_gpus) test.to_cpu();
@@ -753,7 +690,7 @@ void compareResults_parallel(const std::string &name, const Tensor &test, const
   );
 }
 
-void compareResults(const std::string &name, const Tensor &test, const void *ref,
+void compareResults(const std::string &name, Tensor &test, const void *ref,
                     const bool rowwise, double atol, double rtol, bool if_on_gpus,
                     const size_t tolerable_mismatches_limit) {
   constexpr bool sequential = false;
@@ -939,6 +876,7 @@ void generate_data_uniformly(T* data, const size_t size, std::mt19937* gen) {
 }
 
 void fillUniform(Tensor *t) {
+  // Generate random row-wise data and column-wise data
  if (t->rowwise()) {
     const size_t size = product(t->rowwise_shape());
     TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T,
@@ -956,8 +894,12 @@ void fillUniform(Tensor *t) {
       }
     );
   }
-  std::uniform_real_distribution<> dis(-2.0, 1.0);
-  t->set_scale_inv(dis(t->gen()));
+
+  // Generate random scales
+  t->fill_uniform_rowwise_scale_inv();
+  t->fill_uniform_columnwise_scale_inv();
+
+  // Update data on GPU
   t->from_cpu();
 }
@@ -993,7 +935,20 @@ void fillCase_special(Tensor *t) {
       }
     });
   }
-  t->set_scale_inv(1.0);
+
+  // Fill scales
+  if (t->scaling_mode() == NVTE_DELAYED_TENSOR_SCALING) {
+    if (isFp8Type(t->dtype())) {
+      // FP8 tensor scale_inv is set to 1
+      t->set_scale_inv(1.0);
+    }
+  } else {
+    // Block scales are filled randomly
+    t->fill_uniform_rowwise_scale_inv();
+    t->fill_uniform_columnwise_scale_inv();
+  }
+
+  // Update GPU tensor data
   t->from_cpu();
 }
@@ -1027,15 +982,12 @@ template void fillCase(Tensor *t, const InputsFillCase fill_case);
 #endif
 
 void setRandomScale(Tensor *t) {
-  std::uniform_real_distribution<> dis(-2.0, 1.0);
-  const float scale = dis(t->gen());
-  t->set_scale(scale);
+  t->fill_uniform_scale();
 }
 
 void setRandomScaleInv(Tensor *t) {
-  std::uniform_real_distribution<> dis(-2.0, 1.0);
-  const float scale_inv = dis(t->gen());
-  t->set_scale_inv(scale_inv);
+  t->fill_uniform_rowwise_scale_inv();
+  t->fill_uniform_columnwise_scale_inv();
 }
 
 bool isFp8Type(DType type) {
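The heart of the test_common.cu refactor is `Tensor::Buffer`: one RAII object per logical field that pairs a zero-initialized host mirror with a matching device allocation, replacing the hand-written cudaMalloc/cudaFree bookkeeping the old destructor needed. A minimal standalone sketch of the idiom (illustration only; the `PairedBuffer` name is hypothetical, `Tensor::Buffer` above is the real implementation):

    #include <cstddef>
    #include <cstring>
    #include <memory>
    #include <cuda_runtime.h>

    class PairedBuffer {
     public:
      explicit PairedBuffer(std::size_t bytes = 0) : bytes_{bytes} {
        if (bytes_ > 0) {
          cpu_.reset(new unsigned char[bytes_]);   // host mirror, zero-filled
          std::memset(cpu_.get(), 0, bytes_);
          unsigned char *gpu = nullptr;
          cudaMalloc(reinterpret_cast<void **>(&gpu), bytes_);
          cudaMemset(gpu, 0, bytes_);
          gpu_.reset(gpu);                         // freed automatically via GpuFree
        }
      }
      // Device <-> host mirror transfers
      void to_cpu()   { if (bytes_) cudaMemcpy(cpu_.get(), gpu_.get(), bytes_, cudaMemcpyDeviceToHost); }
      void from_cpu() { if (bytes_) cudaMemcpy(gpu_.get(), cpu_.get(), bytes_, cudaMemcpyHostToDevice); }
     private:
      struct GpuFree {
        void operator()(unsigned char *p) const { if (p) cudaFree(p); }
      };
      std::unique_ptr<unsigned char[]> cpu_;        // owns the host side
      std::unique_ptr<unsigned char, GpuFree> gpu_; // owns the device side
      std::size_t bytes_{0};
    };

Because both members are unique_ptrs, the class is movable but not copyable, which is what lets `Tensor` keep `Tensor(Tensor&&) = default` and drop its custom destructor in the header changes below.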
diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h
index b8389d5833..860fc7d7eb 100644
--- a/tests/cpp/test_common.h
+++ b/tests/cpp/test_common.h
@@ -6,10 +6,11 @@
 
 #pragma once
 
-#include
-#include
 #include
+#include
 #include
+#include
+
 #include
 
 #define FP4_TYPE_SUPPORTED (CUDA_VERSION >= 12080)
@@ -27,6 +28,11 @@ namespace test {
 
 using namespace transformer_engine;
 
+size_t typeToNumBits(DType type);
+size_t product(const NVTEShape &shape);
+size_t product(const std::vector<size_t> &shape);
+size_t bytes(const NVTEShape& shape, const DType type);
+
 template <size_t B>
 struct BytesToType {};
@@ -114,7 +120,7 @@ struct TypeInfo {
   }
 
   constexpr static DType dtype = getType();
-  constexpr static size_t size = BitsNumber::num_bits;;
+  constexpr static size_t size = BitsNumber::num_bits;
 };
 
 class Tensor {
@@ -133,7 +139,7 @@ class Tensor {
          const NVTEScalingMode &mode = NVTE_DELAYED_TENSOR_SCALING) :
     Tensor(name, nvte_make_shape(shape.data(), shape.size()), type,
            rowwise, columnwise, mode) {}
 
-  Tensor() {}
+  Tensor() = default;
 
   Tensor& operator=(const Tensor &other) = delete;
   Tensor(const Tensor &other) = delete;
@@ -141,42 +147,7 @@ class Tensor {
   Tensor(Tensor &&other) = default;
   Tensor& operator=(Tensor &&other) = default;
 
-  ~Tensor() {
-    void *data_ptr = tensor_.dptr();
-    void *scale_inv = tensor_.scale_inv();
-    void *columnwise_data_ptr = tensor_.get_columnwise_data().data_ptr;
-    void *columnwise_scale_inv = tensor_.get_columnwise_scale_inv().data_ptr;
-    void *amax = tensor_.amax();
-    void *columnwise_amax_ptr = tensor_.get_columnwise_amax().data_ptr;
-    void *scale = tensor_.scale();
-    if (columnwise_data_ptr == data_ptr) {
-      columnwise_data_ptr = nullptr;
-    }
-    if (columnwise_scale_inv == scale_inv) {
-      columnwise_scale_inv = nullptr;
-    }
-    if (data_ptr != nullptr) {
-      cudaFree(data_ptr);
-    }
-    if (scale_inv != nullptr) {
-      cudaFree(scale_inv);
-    }
-    if (columnwise_data_ptr != nullptr) {
-      cudaFree(columnwise_data_ptr);
-    }
-    if (columnwise_scale_inv != nullptr) {
-      cudaFree(columnwise_scale_inv);
-    }
-    if (amax != nullptr) {
-      cudaFree(amax);
-    }
-    if (columnwise_amax_ptr != nullptr) {
-      cudaFree(columnwise_amax_ptr);
-    }
-    if (scale != nullptr) {
-      cudaFree(scale);
-    }
-  }
+  ~Tensor() = default;
 
   NVTETensor data() const noexcept { return tensor_.data(); }
@@ -213,84 +184,57 @@ class Tensor {
   }
 
   template <typename T>
-  T *rowwise_cpu_dptr() const {
-    NVTE_CHECK(TypeInfo<T>::dtype == tensor_.dtype(), "Invalid type!");
+  T *rowwise_cpu_dptr() {
+    NVTE_CHECK(TypeInfo<T>::dtype == data_rowwise_.dtype(), "Invalid type!");
     NVTE_CHECK(rowwise_, "Tensor does not have rowwise data!");
-    return reinterpret_cast<T *>(cpu_data_rowwise_.get());
+    return data_rowwise_.cpu_buffer<T>();
   }
 
   template <typename T>
-  T *columnwise_cpu_dptr() const {
-    NVTE_CHECK(TypeInfo<T>::dtype == tensor_.dtype(), "Invalid type!");
+  T *columnwise_cpu_dptr() {
+    NVTE_CHECK(TypeInfo<T>::dtype == data_columnwise_.dtype(), "Invalid type!");
     NVTE_CHECK(columnwise_, "Tensor does not have columnwise data!");
-    return reinterpret_cast<T *>(cpu_data_columnwise_.get());
+    return data_columnwise_.cpu_buffer<T>();
   }
 
-  float amax() const {
-    if(amax_cpu_data_) {
-      to_cpu();
-      return *amax_cpu_data_;
-    } else {
-      return 0;
-    }
+  float amax() {
+    NVTE_CHECK(amax_rowwise_.size() == 1);
+    NVTE_CHECK(amax_rowwise_.dtype() == DType::kFloat32);
+    amax_rowwise_.to_cpu();
+    return *amax_rowwise_.cpu_buffer<float>();
   }
 
-  float amax_columnwise() const {
-    if(amax_cpu_data_columnwise_) {
-      to_cpu();
-      return *amax_cpu_data_columnwise_;
-    } else {
-      return 0;
-    }
+  float amax_columnwise() {
+    NVTE_CHECK(amax_columnwise_.size() == 1);
+    NVTE_CHECK(amax_columnwise_.dtype() == DType::kFloat32);
+    amax_columnwise_.to_cpu();
+    return *amax_columnwise_.cpu_buffer<float>();
   }
 
-  float scale() const {
-    if(scale_cpu_data_) {
-      NVTE_CHECK(tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING, "Invalid scaling_mode!");
-      to_cpu();
-      return *scale_cpu_data_;
-    } else {
-      return 1;
-    }
+  float scale() {
+    NVTE_CHECK(scale_.size() == 1);
+    NVTE_CHECK(scale_.dtype() == DType::kFloat32);
+    scale_.to_cpu();
+    return *scale_.cpu_buffer<float>();
   }
 
   template <typename T>
   T *rowwise_cpu_scale_inv_ptr(){
-    if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING){
-      NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!");
-    } else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D || tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) {
-      NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!");
-    } else if (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING) {
-      NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat8E4M3, "Invalid type!");
-    } else {
-      NVTE_CHECK(TypeInfo<T>::dtype == DType::kByte, "Invalid type!");
-    }
-    to_cpu();
-    return reinterpret_cast<T *>(rowwise_scale_inv_cpu_data_.get());
+    scale_inv_rowwise_.to_cpu();
+    return scale_inv_rowwise_.cpu_buffer<T>();
   }
 
   template <typename T>
   T *columnwise_cpu_scale_inv_ptr(){
-    if (tensor_.scaling_mode() == NVTE_DELAYED_TENSOR_SCALING){
-      NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!");
-    } else if (tensor_.scaling_mode() == NVTE_BLOCK_SCALING_1D || tensor_.scaling_mode() == NVTE_BLOCK_SCALING_2D) {
-      NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat32, "Invalid type!");
-    } else if (tensor_.scaling_mode() == NVTE_NVFP4_1D_SCALING) {
-      NVTE_CHECK(TypeInfo<T>::dtype == DType::kFloat8E4M3, "Invalid type!");
-    } else {
-      NVTE_CHECK(TypeInfo<T>::dtype == DType::kByte, "Invalid type!");
-    }
-    to_cpu();
-    return reinterpret_cast<T *>(columnwise_scale_inv_cpu_data_.get());
+    scale_inv_columnwise_.to_cpu();
+    return scale_inv_columnwise_.cpu_buffer<T>();
   }
 
   float rowwise_scale_inv(){
-    if(rowwise_scale_inv_cpu_data_) {
-      float scale_inv = rowwise_cpu_scale_inv_ptr<float>()[0];
-      return scale_inv;
-    } else {
-      return 1;
-    }
+    NVTE_CHECK(scale_inv_rowwise_.size() == 1);
+    NVTE_CHECK(scale_inv_rowwise_.dtype() == DType::kFloat32);
+    scale_inv_rowwise_.to_cpu();
+    return *scale_inv_rowwise_.cpu_buffer<float>();
   }
 
   bool rowwise() const {
@@ -301,20 +245,6 @@ class Tensor {
     return columnwise_;
   }
 
-  void set_tensor_amax(const float amax) {
-    if (amax_cpu_data_) {
-      *amax_cpu_data_ = amax;
-      from_cpu();
-    }
-  }
-
-  void set_tensor_amax_columnwise(const float amax) {
-    if (amax_cpu_data_columnwise_) {
-      *amax_cpu_data_columnwise_ = amax;
-      from_cpu();
-    }
-  }
-
   void set_tensor_amax_nullptr(){
     tensor_.set_amax(nullptr, DType::kFloat32, tensor_.defaultShape);
   }
@@ -323,23 +253,90 @@ class Tensor {
     tensor_.set_with_gemm_swizzled_scales(with_gemm_swizzled_scales);
   }
 
-  void to_cpu() const;
-  void from_cpu() const;
+  void to_cpu();
+  void from_cpu();
+
+  void set_amax(float amax);
   void set_scale(float scale);
   void set_scale_inv(float scale_inv);
-  void shareFP8Meta(const Tensor &other);
+  void set_tensor_amax(float amax);
+  void set_tensor_amax_columnwise(float amax);
+
+  void fill_uniform_rowwise_scale_inv();
+  void fill_uniform_columnwise_scale_inv();
+  void fill_uniform_scale();
 
   std::mt19937& gen() { return gen_; }
 
  private:
+
+  /* Manages matching GPU and CPU buffers. */
+  class Buffer {
+   public:
+
+    Buffer(size_t size = 0, DType dtype = DType::kByte);
+    ~Buffer() = default;
+    Buffer(const Buffer&) = delete;
+    Buffer& operator=(const Buffer&) = delete;
+    Buffer(Buffer&&) = default;
+    Buffer& operator=(Buffer&&) = default;
+
+    size_t size() const noexcept { return size_; }
+    DType dtype() const noexcept { return dtype_; }
+
+    // Void pointer accessors
+    void *cpu_buffer() { return cpu_buffer_.get(); }
+    const void *cpu_buffer() const { return cpu_buffer_.get(); }
+    void *gpu_buffer() { return gpu_buffer_.get(); }
+    const void *gpu_buffer() const { return gpu_buffer_.get(); }
+
+    // Templated pointer accessors
+    template <typename T>
+    T *cpu_buffer() {
+      return reinterpret_cast<T *>(cpu_buffer());
+    }
+    template <typename T>
+    const T *cpu_buffer() const {
+      return const_cast<Buffer *>(this)->cpu_buffer<T>();
+    }
+    template <typename T>
+    T *gpu_buffer() {
+      return reinterpret_cast<T *>(gpu_buffer());
+    }
+    template <typename T>
+    const T *gpu_buffer() const {
+      return const_cast<Buffer *>(this)->gpu_buffer<T>();
+    }
+
+    // Memory transfers between CPU and GPU
+    void to_cpu();
+    void from_cpu();
+
+   private:
+
+    struct GPUDeleter {
+      void operator()(void *ptr);
+    };
+
+    std::unique_ptr<unsigned char[]> cpu_buffer_;
+    std::unique_ptr<unsigned char, GPUDeleter> gpu_buffer_;
+    size_t size_;
+    DType dtype_;
+    size_t bytes_;
+  };
+
+  // Transformer Engine tensor
   TensorWrapper tensor_;
-  std::unique_ptr<unsigned char[]> cpu_data_rowwise_;
-  std::unique_ptr<unsigned char[]> cpu_data_columnwise_;
-  std::shared_ptr<float> amax_cpu_data_;
-  std::shared_ptr<float> amax_cpu_data_columnwise_;
-  std::shared_ptr<float> scale_cpu_data_;
-  std::unique_ptr<unsigned char[]> rowwise_scale_inv_cpu_data_;
-  std::unique_ptr<unsigned char[]> columnwise_scale_inv_cpu_data_;
+
+  // Data buffers
+  Buffer data_rowwise_;
+  Buffer data_columnwise_;
+  Buffer scale_inv_rowwise_;
+  Buffer scale_inv_columnwise_;
+  Buffer amax_rowwise_;
+  Buffer amax_columnwise_;
+  Buffer scale_;
+
   bool rowwise_;
   bool columnwise_;
   std::string name_;
@@ -489,17 +486,12 @@ inline float dsilu(const float x) { return x * dsigmoid(x) + sigmoid(x); }
 inline float srelu(const float x) { return x > 0 ? x * x : 0; }
 inline float dsrelu(const float x) { return fmaxf(0, 2 * x); }
 
-size_t typeToNumBits(DType type);
-size_t product(const NVTEShape &shape);
-size_t product(const std::vector<size_t> &shape);
-size_t bytes(const NVTEShape& shape, const DType type);
-
 size_t first_dimension(const std::vector<size_t> &shape);
 size_t last_dimension(const std::vector<size_t> &shape);
 
 bool areShapesEqual(const NVTEShape &s1, const NVTEShape &s2);
 
-void compareResults(const std::string &name, const Tensor &test, const void *ref,
+void compareResults(const std::string &name, Tensor &test, const void *ref,
                     bool rowwise, double atol = 1e-5, double rtol = 1e-8, bool if_on_gpus = true,
                     const size_t tolerable_mismatches_limit = 0);
 void compareResults(const std::string &name, const float test, const float ref,
diff --git a/tests/cpp_distributed/test_comm_gemm.cu b/tests/cpp_distributed/test_comm_gemm.cu
index cc0d760a39..45f6664567 100644
--- a/tests/cpp_distributed/test_comm_gemm.cu
+++ b/tests/cpp_distributed/test_comm_gemm.cu
@@ -107,8 +107,10 @@ std::vector CopyMatrix(const std::vector& data, size_t mstart, size_t nsta
 template <typename T>
 test::Tensor Make(size_t m, size_t n, float scale) {
   test::Tensor ret("", std::vector<size_t>{n, m}, TypeInfo<T>::dtype);
-  ret.set_scale(scale);
-  ret.set_scale_inv(1.0 / scale);
+  if (test::isFp8Type(TypeInfo<T>::dtype)) {
+    ret.set_scale(scale);
+    ret.set_scale_inv(1.0 / scale);
+  }
   return ret;
 }
 
@@ -116,8 +118,10 @@ template <typename T>
 test::Tensor MakeFromData(const std::vector& data, size_t mstart, size_t nstart, size_t msize,
                           size_t nsize, size_t ld, float scale) {
   test::Tensor ret("", std::vector<size_t>{nsize, msize}, TypeInfo<T>::dtype);
-  ret.set_scale(scale);
-  ret.set_scale_inv(1.0 / scale);
+  if (test::isFp8Type(TypeInfo<T>::dtype)) {
+    ret.set_scale(scale);
+    ret.set_scale_inv(1.0 / scale);
+  }
   auto local = CopyMatrix(data, mstart, nstart, msize, nsize, ld);
   NVTE_CHECK_CUDA(cudaMemcpy(ret.rowwise_dptr(), local.data(), local.size() * sizeof local[0],
                              cudaMemcpyDefault));
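A closing note on the signature churn threaded through these files: `to_cpu()` now refreshes the Tensor's owned host mirrors, so it can no longer be `const`, and the comparison helpers take `Tensor &` instead of `const Tensor &`. A hedged usage sketch (hypothetical `out`/`ref_output` names; `compareResults` as declared in test_common.h):

    test::Tensor out("out", std::vector<size_t>{N, H}, otype);
    // ... launch a kernel that writes into out's GPU buffers ...
    compareResults("output", out, ref_output.get(), /*rowwise=*/true, atol, rtol);
    // compareResults calls out.to_cpu() internally (if_on_gpus defaults to true),
    // which mutates the host mirrors - hence the non-const reference.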