38 changes: 24 additions & 14 deletions transformer_engine/pytorch/csrc/common.h
@@ -101,7 +101,9 @@ class Quantizer {

/*! @brief Construct a tensor with uninitialized data */
virtual std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
DType dtype) const = 0;
DType dtype,
at::Device device = torch::kCUDA,
bool pin_memory = false) const = 0;

/*! @brief Convert a PyTorch tensor into a Transformer Engine C++ tensor
*
@@ -135,8 +137,9 @@ class NoneQuantizer : public Quantizer {

void set_quantization_params(TensorWrapper* tensor) const override {}

std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
DType dtype) const override;
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape, DType dtype,
at::Device device = torch::kCUDA,
bool pin_memory = false) const override;

/*! @brief Construct a tensor with pre-initialized data */
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape, DType dtype,
@@ -161,14 +164,17 @@ class Float8Quantizer : public Quantizer {

void set_quantization_params(TensorWrapper* tensor) const override;

std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
DType dtype) const override;
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape, DType dtype,
at::Device device = torch::kCUDA,
bool pin_memory = false) const override;

/*! @brief Construct a tensor with pre-initialized data */
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape, DType dtype,
std::optional<at::Tensor> data,
std::optional<at::Tensor> transpose,
std::optional<at::Tensor> scale_inv) const;
std::optional<at::Tensor> scale_inv,
at::Device device = torch::kCUDA,
bool pin_memory = false) const;

std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;

@@ -193,8 +199,9 @@ class Float8CurrentScalingQuantizer : public Quantizer {

void set_quantization_params(TensorWrapper* tensor) const override;

std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
DType dtype) const override;
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape, DType dtype,
at::Device device = torch::kCUDA,
bool pin_memory = false) const override;

/*! @brief Construct an unquantized tensor that shares the quantizer's amax pointer.
*
@@ -250,8 +257,9 @@ class Float8BlockQuantizer : public Quantizer {
// Create a python Float8BlockQuantized tensor and C++ wrapper
// for the tensor. Should set quantized data, scales for rowwise
// and optionally columnwise usage.
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
DType dtype) const override;
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape, DType dtype,
at::Device device = torch::kCUDA,
bool pin_memory = false) const override;

std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;

@@ -271,8 +279,9 @@ class MXFP8Quantizer : public Quantizer {

void set_quantization_params(TensorWrapper* tensor) const override;

std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
DType dtype) const override;
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape, DType dtype,
at::Device device = torch::kCUDA,
bool pin_memory = false) const override;

std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;

@@ -305,8 +314,9 @@ class NVFP4Quantizer : public Quantizer {

void set_quantization_params(TensorWrapper* tensor) const override;

std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
DType dtype) const override;
std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape, DType dtype,
at::Device device = torch::kCUDA,
bool pin_memory = false) const override;

/*! @brief Construct an unquantized tensor that shares NVFP4 tensor's amax pointer
*
3 changes: 3 additions & 0 deletions transformer_engine/pytorch/csrc/extensions.h
@@ -250,6 +250,9 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob

py::object dequantize(const py::handle &input, DType otype);

py::object create_empty_quantized_tensor(py::handle quantizer, const std::vector<size_t> &shape,
at::ScalarType dtype, at::Device device, bool pin_memory);

std::vector<py::object> multi_tensor_quantize(const std::vector<at::Tensor> &tensor_list,
std::vector<py::handle> quantizer_list);

8 changes: 8 additions & 0 deletions transformer_engine/pytorch/csrc/extensions/cast.cpp
@@ -80,6 +80,14 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob
return output_py;
}

py::object create_empty_quantized_tensor(py::handle quantizer, const std::vector<size_t> &shape,
at::ScalarType dtype, at::Device device, bool pin_memory) {
auto quantizer_cpp = convert_quantizer(quantizer);
auto te_dtype = GetTransformerEngineDType(dtype);
auto [_, output_py] = quantizer_cpp->create_tensor(shape, te_dtype, device, pin_memory);
return output_py;
}

py::object dequantize(const py::handle &input, transformer_engine::DType otype) {
init_extension();

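For orientation, here is a minimal sketch of how this new entry point could be exercised from Python once the binding in pybind.cpp below is registered. The quantizer object is assumed to be constructed elsewhere, and the shape, dtype, and device values are purely illustrative.

    import torch
    import transformer_engine_torch as tex

    # `quantizer` is assumed to be an already-constructed quantizer object
    # (e.g. a Float8Quantizer); convert_quantizer() resolves it on the C++ side.
    out = tex.create_empty_quantized_tensor(
        quantizer,            # py::handle for the quantizer
        [1024, 1024],         # shape as a list of sizes
        torch.float32,        # at::ScalarType, mapped to a TE DType internally
        torch.device("cpu"),  # device for all allocated buffers
        True,                 # pin_memory: request page-locked host memory
    )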
4 changes: 4 additions & 0 deletions transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -121,6 +121,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
py::arg("output") = py::none(), py::arg("noop") = py::none());
m.def("dequantize", &transformer_engine::pytorch::dequantize, "Dequantize", py::arg("input"),
py::arg("otype"));
m.def("create_empty_quantized_tensor",
&transformer_engine::pytorch::create_empty_quantized_tensor,
"Create an empty quantized tensor", py::arg("quantizer"), py::arg("shape"),
py::arg("dtype"), py::arg("device"), py::arg("pin_memory"));

m.def("bgrad_quantize", transformer_engine::pytorch::bgrad_quantize,
"Compute bias gradient and quantize", py::arg("input"), py::arg("quantizer"));
55 changes: 35 additions & 20 deletions transformer_engine/pytorch/csrc/quantizer.cpp
@@ -73,9 +73,11 @@ Float8Quantizer::Float8Quantizer(const py::handle& quantizer) : Quantizer(quanti
}

std::pair<TensorWrapper, py::object> NoneQuantizer::create_tensor(const std::vector<size_t>& shape,
DType dtype) const {
DType dtype, at::Device device,
bool pin_memory) const {
const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
const auto opts = at::TensorOptions().dtype(GetATenDType(dtype)).device(torch::kCUDA);
const auto opts =
at::TensorOptions().dtype(GetATenDType(dtype)).device(device).pinned_memory(pin_memory);
return create_tensor(shape, dtype, at::empty(shape_int64, opts));
}

@@ -113,22 +115,26 @@ void Float8Quantizer::set_quantization_params(TensorWrapper* tensor) const {
}

std::pair<TensorWrapper, py::object> Float8Quantizer::create_tensor(
const std::vector<size_t>& shape, DType dtype) const {
const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
const std::vector<size_t>& shape, DType dtype, at::Device device, bool pin_memory) const {
const auto opts =
at::TensorOptions().dtype(torch::kFloat32).device(device).pinned_memory(pin_memory);
at::Tensor scale_inv = at::empty(std::vector<int64_t>{1}, opts);
return create_tensor(shape, dtype, std::nullopt, std::nullopt, std::move(scale_inv));
return create_tensor(shape, dtype, std::nullopt, std::nullopt, std::move(scale_inv), device,
pin_memory);
}

std::pair<TensorWrapper, py::object> Float8Quantizer::create_tensor(
const std::vector<size_t>& shape, DType dtype, std::optional<at::Tensor> data,
std::optional<at::Tensor> transpose, std::optional<at::Tensor> scale_inv) const {
std::optional<at::Tensor> transpose, std::optional<at::Tensor> scale_inv, at::Device device,
bool pin_memory) const {
using namespace pybind11::literals;

// Initialize data tensor
const bool with_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported();
if (with_data && !data) {
const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
const auto opts =
at::TensorOptions().dtype(torch::kUInt8).device(device).pinned_memory(pin_memory);
data = at::empty(shape_int64, opts);
} else if (!with_data && data) {
data.reset();
@@ -139,7 +145,8 @@ std::pair<TensorWrapper, py::object> Float8Quantizer::create_tensor(
const bool with_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported();
if (with_transpose && !transpose) {
const auto transpose_shape = make_transpose_shape<int64_t>(shape);
const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
const auto opts =
at::TensorOptions().dtype(torch::kUInt8).device(device).pinned_memory(pin_memory);
transpose = at::empty(transpose_shape, opts);
} else if (!with_transpose && transpose) {
transpose.reset();
@@ -325,15 +332,16 @@ void Float8CurrentScalingQuantizer::set_quantization_params(TensorWrapper* tenso
}

std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::create_tensor(
const std::vector<size_t>& shape, DType dtype) const {
const std::vector<size_t>& shape, DType dtype, at::Device device, bool pin_memory) const {
using namespace pybind11::literals;

// Initialize data tensor
at::Tensor data_tensor;
const bool with_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported();
if (with_data) {
const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
const auto opts =
at::TensorOptions().dtype(torch::kUInt8).device(device).pinned_memory(pin_memory);
data_tensor = at::empty(shape_int64, opts);
}

@@ -342,15 +350,17 @@ std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::create_tenso
const bool with_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported();
if (with_transpose) {
const auto transpose_shape = make_transpose_shape<int64_t>(shape);
const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
const auto opts =
at::TensorOptions().dtype(torch::kUInt8).device(device).pinned_memory(pin_memory);
transpose_tensor = at::empty(transpose_shape, opts);
}

// Initialize scale-inverse tensor
at::Tensor scale_inv_tensor;
{
const std::vector<int64_t> scale_inv_shape = {1};
const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
const auto opts =
at::TensorOptions().dtype(torch::kFloat32).device(device).pinned_memory(pin_memory);
scale_inv_tensor = at::empty(scale_inv_shape, opts);
}

@@ -562,7 +572,7 @@ Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quanti
void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const {}

std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
const std::vector<size_t>& shape, DType dtype) const {
const std::vector<size_t>& shape, DType dtype, at::Device device, bool pin_memory) const {
using namespace pybind11::literals;
std::vector<int64_t> torch_shape;
for (auto s : shape) {
@@ -573,8 +583,8 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
at::TensorOptions opts;
at::TensorOptions scale_opts;
at::Tensor data_rowwise, data_colwise, scale_inv_rowwise, scale_inv_colwise;
opts = opts.dtype(torch::kUInt8).device(torch::kCUDA);
scale_opts = scale_opts.dtype(torch::kFloat32).device(torch::kCUDA);
opts = opts.dtype(torch::kUInt8).device(device).pinned_memory(pin_memory);
scale_opts = scale_opts.dtype(torch::kFloat32).device(device).pinned_memory(pin_memory);

if (rowwise_usage) {
data_rowwise = at::empty(torch_shape, opts);
@@ -858,7 +868,8 @@ MXFP8Quantizer::MXFP8Quantizer(const py::handle& quantizer) : Quantizer(quantize
void MXFP8Quantizer::set_quantization_params(TensorWrapper* tensor) const {}

std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(const std::vector<size_t>& shape,
DType dtype) const {
DType dtype, at::Device device,
bool pin_memory) const {
using namespace pybind11::literals;

// Scaling factor format
@@ -882,7 +893,8 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(const std::ve
// Allocate tensors
at::Tensor rowwise_data_tensor, rowwise_scale_inv_tensor;
at::Tensor columnwise_data_tensor, columnwise_scale_inv_tensor;
const auto uint8_tensor_opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
const auto uint8_tensor_opts =
at::TensorOptions().dtype(torch::kUInt8).device(device).pinned_memory(pin_memory);
if (rowwise_usage) {
const std::vector<int64_t> scale_inv_shape_int64(rowwise_scale_inv_shape.begin(),
rowwise_scale_inv_shape.end());
@@ -1132,7 +1144,8 @@ void NVFP4Quantizer::set_quantization_params(TensorWrapper* tensor) const {
}

std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_tensor(const std::vector<size_t>& shape,
DType dtype) const {
DType dtype, at::Device device,
bool pin_memory) const {
using namespace pybind11::literals;

// Scaling factor format
@@ -1158,8 +1171,10 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_tensor(const std::ve
// Allocate tensors
at::Tensor rowwise_data_tensor, rowwise_scale_inv_tensor, amax_rowwise;
at::Tensor columnwise_data_tensor, columnwise_scale_inv_tensor, amax_columnwise;
const auto bit8_tensor_opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
const auto bit32_tensor_opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
const auto bit8_tensor_opts =
at::TensorOptions().dtype(torch::kUInt8).device(device).pinned_memory(pin_memory);
const auto bit32_tensor_opts =
at::TensorOptions().dtype(torch::kFloat32).device(device).pinned_memory(pin_memory);
if (rowwise_usage) {
const std::vector<int64_t> scale_inv_shape_int64(rowwise_scale_inv_shape.begin(),
rowwise_scale_inv_shape.end());
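The net effect of these quantizer.cpp changes is that every internal allocation (data, transpose, scale-inverse, amax) now honors the requested device and pinned-memory flag instead of hard-coding torch::kCUDA. A hedged sketch of what that enables, assuming `quantizer` is an existing Float8Quantizer with row-wise usage enabled and that the returned tensor exposes its raw buffer as `_data` (attribute name assumed):

    import torch
    import transformer_engine_torch as tex

    # Allocate the quantized buffers in page-locked host memory.
    staged = tex.create_empty_quantized_tensor(
        quantizer, [256, 256], torch.bfloat16, torch.device("cpu"), True
    )
    # Pinned host buffers permit asynchronous host-to-device copies later on.
    assert staged._data.is_pinned()  # attribute name assumed
    gpu_copy = staged._data.to("cuda", non_blocking=True)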
23 changes: 19 additions & 4 deletions transformer_engine/pytorch/quantized_tensor.py
@@ -13,6 +13,8 @@
import torch
from torch.utils._pytree import tree_map

import transformer_engine_torch as tex

from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch.tensor._quantization_helpers import (
_QuantizeFunc,
@@ -272,13 +274,26 @@ def make_empty(
shape: Iterable[int],
*,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
device: Optional[Union[torch.device, str]] = None,
requires_grad: bool = False,
pin_memory: bool = False,
) -> QuantizedTensor:
"""Construct quantized tensor with uninitialized data"""
raise NotImplementedError(
f"{self.__class__.__name__} class does not implement make_empty function, "
"required for construction of unintialized quantized tensor"

if device is None:
device = torch.device("cuda")
# Handle the device passed as string
device = torch.device(device)
result = tex.create_empty_quantized_tensor(
self,
list(shape),
dtype,
device,
pin_memory,
)
if requires_grad:
result.requires_grad_(True)
return result

def calibrate(self, tensor: torch.Tensor) -> None:
"""Calibrate quantizer state
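With the base-class make_empty now delegating to tex.create_empty_quantized_tensor, quantizers backed by the extended C++ create_tensor overloads pick up device and pin_memory support without a Python-side override. A usage sketch; the Float8Quantizer constructor arguments are assumptions about its delayed-scaling state (scale and amax buffers plus an FP8 dtype):

    import torch
    import transformer_engine_torch as tex
    from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer

    # Constructor arguments are assumptions, shown only to make the sketch complete.
    quantizer = Float8Quantizer(
        scale=torch.ones(1, device="cuda"),
        amax=torch.zeros(1, device="cuda"),
        fp8_dtype=tex.DType.kFloat8E4M3,
    )

    # Uninitialized quantized tensor in pinned host memory; the string device
    # is normalized through torch.device() inside make_empty.
    staging = quantizer.make_empty(
        (128, 512),
        dtype=torch.bfloat16,
        device="cpu",
        pin_memory=True,
    )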
56 changes: 0 additions & 56 deletions transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
@@ -202,62 +202,6 @@ def is_quantizable(self, inp: torch.Tensor) -> bool:
return False
return True

def make_empty(
self,
shape: Iterable[int],
*,
dtype: torch.dtype = torch.float32,
device: Optional[torch.device] = None,
requires_grad: bool = False,
pin_memory: bool = False,
) -> Float8BlockwiseQTensor:
"""Construct quantized tensor with uninitialized data"""

tensor_kwargs = {
"device": torch.device("cuda") if device is None else device,
"pin_memory": pin_memory,
}

# Allocate buffers for row-scaled data
rowwise_data = None
rowwise_scale_inv = None
if self.rowwise_usage:
rowwise_data = torch.empty(shape, dtype=torch.uint8, **tensor_kwargs)
rowwise_scale_inv = torch.empty(
self.get_scale_shape(shape, columnwise=False),
dtype=torch.float32,
**tensor_kwargs,
)

# Allocate buffers for column-scaled data
columnwise_data = None
columnwise_scale_inv = None
if self.columnwise_usage:
columnwise_data = torch.empty(
self.get_columnwise_shape(shape),
dtype=torch.uint8,
**tensor_kwargs,
)
columnwise_scale_inv = torch.empty(
self.get_scale_shape(shape, columnwise=True),
dtype=torch.float32,
**tensor_kwargs,
)

# Construct FP8 tensor
return Float8BlockwiseQTensor(
shape=shape,
dtype=dtype,
fp8_dtype=self.dtype,
rowwise_data=rowwise_data,
rowwise_scale_inv=rowwise_scale_inv,
columnwise_data=columnwise_data,
columnwise_scale_inv=columnwise_scale_inv,
quantizer=self,
is_2D_scaled=self.block_scaling_dim == 2,
requires_grad=requires_grad,
)

def calibrate(self, tensor: torch.Tensor) -> None:
# NOTE: This interface is specific to requirements like delayed scaling
# where state from an estimator influences distribution parameters.
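The removed override is now redundant: Float8BlockQuantizer.make_empty resolves to the generic base-class path above, which calls the C++ Float8BlockQuantizer::create_tensor overload to allocate the row-wise and column-wise data and scale buffers. A short sketch, assuming `block_quantizer` is an already-constructed Float8BlockQuantizer:

    import torch

    # `block_quantizer` is assumed to be an existing Float8BlockQuantizer.
    # Row-/column-wise buffers and their scales are allocated in C++ with the
    # requested device and pin_memory flag.
    blockwise = block_quantizer.make_empty(
        (1024, 1024),
        dtype=torch.float32,
        device=torch.device("cuda"),
        pin_memory=False,
    )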