diff --git a/tests/cpp/operator/test_grouped_gemm.cu b/tests/cpp/operator/test_grouped_gemm.cu index a694052b15..73969ca297 100644 --- a/tests/cpp/operator/test_grouped_gemm.cu +++ b/tests/cpp/operator/test_grouped_gemm.cu @@ -88,16 +88,16 @@ struct TestParams { std::vector> make_shapes(ShapeCase scase) { switch (scase) { case ShapeCase::kAllSame: - return {{64, 64, 32}, {64, 64, 32}, {64, 64, 32}}; + return {{128, 256, 384}, {128, 256, 384}, {128, 256, 384}}; case ShapeCase::kSameFirst: // Same M (first dim), varying N and K - return {{64, 80, 32}, {64, 96, 48}, {64, 112, 64}}; + return {{128, 256, 384}, {128, 384, 512}, {128, 512, 640}}; case ShapeCase::kSameLast: // Same N (last dim), varying M and K - return {{64, 80, 32}, {80, 80, 48}, {96, 80, 64}}; + return {{128, 256, 384}, {256, 256, 512}, {384, 256, 640}}; case ShapeCase::kAllDifferent: default: - return {{64, 96, 32}, {80, 112, 48}, {96, 128, 64}}; + return {{128, 256, 384}, {256, 384, 512}, {384, 512, 640}}; } } @@ -123,10 +123,11 @@ void run_grouped_gemm_case(const TestParams& params) { for (size_t i = 0; i < num_gemms; ++i) { const auto [M, N, K] = shapes[i]; - const std::vector a_shape = params.transa ? std::vector{M, K} - : std::vector{K, M}; - const std::vector b_shape = params.transb ? std::vector{K, N} - : std::vector{N, K}; + + const std::vector a_shape = params.transa ? std::vector{N, K} + : std::vector{K, N}; + const std::vector b_shape = params.transb ? std::vector{K, M} + : std::vector{M, K}; switch (params.input_case) { case InputCase::kFP8Current: { A_tensors.emplace_back(make_fp8_operand("A" + std::to_string(i), a_shape)); @@ -247,6 +248,8 @@ void run_grouped_gemm_case(const TestParams& params) { nullptr, // config (use defaults) 0); + NVTE_CHECK_CUDA(cudaDeviceSynchronize()); + // Compare results for (size_t i = 0; i < num_gemms; ++i) { Tensor grouped_split("grouped_D" + std::to_string(i), std::vector{static_cast(std::get<0>(shapes[i])), @@ -289,7 +292,6 @@ std::string MakeGroupedGemmTestName(const testing::TestParamInfo kTestParams = { // Basic tests - {InputCase::kFP8Current, true, false, ShapeCase::kAllDifferent, false}, {InputCase::kFP8Current, false, true, ShapeCase::kAllDifferent, false}, {InputCase::kFP8Current, false, false, ShapeCase::kAllSame, false}, {InputCase::kBF16, true, false, ShapeCase::kSameFirst, false}, diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index abe2806e66..366b2a62c5 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -46,7 +46,12 @@ is_nvfp4_available, ) from transformer_engine.pytorch import checkpoint as te_checkpoint -from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm +from transformer_engine.pytorch.cpp_extensions import ( + general_gemm, + general_grouped_gemm, + general_grouped_gemm_for_grouped_tensor, +) +from transformer_engine.pytorch.tensor.storage.grouped_tensor import GroupedTensor from transformer_engine.common import recipe import transformer_engine_torch as tex from utils import ModelConfig, reset_rng_states @@ -1991,6 +1996,82 @@ def test_grouped_linear_accuracy( torch.testing.assert_close(o, o_ref, rtol=0, atol=0) +@pytest.mark.parametrize("single_weight", [True, False], ids=["single_weight", "multi_weight"]) +def test_grouped_linear_m_splits_tensor(single_weight): + """Test GroupedLinear with m_splits as torch tensor (no_quantization/bf16). 
+ grouped_tensor_path is chosen and must match reference (single_weight vs reference model, + or multi_weight list m_splits vs tensor m_splits). + """ + if tex.get_cublasLt_version() < 130200: + pytest.skip("Grouped GEMM requires cuBLAS 13.2+.") + if torch.cuda.get_device_capability() < (10, 0): + pytest.skip("Grouped GEMM requires Blackwell (SM100) or newer.") + if not is_bf16_available(): + pytest.skip("bfloat16 is required for grouped GEMM test.") + + torch.manual_seed(0) + num_gemms = 3 + in_features = 32 + out_features = 64 + m_splits = torch.tensor([5, 7, 9], device="cuda", dtype=torch.int64) + m_splits_list = [5, 7, 9] + dtype = torch.bfloat16 + m_total = int(m_splits.sum().item()) + + reference_model = GroupedLinear( + num_gemms, + in_features, + out_features, + bias=False, + params_dtype=dtype, + device="cuda", + single_weight=False, + ) + with torch.no_grad(): + ref_weights = [getattr(reference_model, f"weight{i}") for i in range(num_gemms)] + + test_model = GroupedLinear( + num_gemms, + in_features, + out_features, + bias=False, + params_dtype=dtype, + device="cuda", + single_weight=single_weight, + ) + with torch.no_grad(): + if single_weight: + for i, w in enumerate(test_model.grouped_weight_storage.split_into_quantized_tensors()): + w.copy_(ref_weights[i]) + else: + for i in range(num_gemms): + getattr(test_model, f"weight{i}").copy_(ref_weights[i]) + + inp = torch.randn(m_total, in_features, device="cuda", dtype=dtype, requires_grad=True) + inp_ref = inp.detach().clone().requires_grad_() + + if single_weight: + out = test_model(inp, m_splits) + out_ref = reference_model(inp_ref, m_splits) + else: + out = test_model(inp, m_splits) + out_ref = reference_model(inp_ref, m_splits_list) + + torch.testing.assert_close(out, out_ref, **dtype_tols(dtype)) + + out.sum().backward() + out_ref.sum().backward() + + torch.testing.assert_close(inp.grad, inp_ref.grad, **dtype_tols(dtype)) + if single_weight: + ref_wgrad = torch.cat( + [getattr(reference_model, f"weight{i}").grad.view(-1) for i in range(num_gemms)] + ) + torch.testing.assert_close( + getattr(test_model, "weight0").grad, ref_wgrad, **dtype_tols(dtype) + ) + + @pytest.mark.skipif( torch.cuda.get_device_capability() != (9, 0), reason="Only enable CUTLASS grouped gemm on Hopper", @@ -2790,6 +2871,124 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass): os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None) +def _pack_grouped_tensor(grouped_tensor: GroupedTensor, tensors: List[torch.Tensor]) -> None: + offset = 0 + for tensor in tensors: + numel = tensor.numel() + grouped_tensor.data[offset : offset + numel].copy_(tensor.reshape(-1)) + offset += numel + + +@pytest.mark.parametrize("layout", ["TN", "NN", "NT"]) +@pytest.mark.parametrize("accumulate", [False]) +def test_grouped_gemm_grouped_tensor(layout, accumulate): + if tex.get_cublasLt_version() < 130200: + pytest.skip("Grouped GEMM requires cuBLAS 13.2+.") + if torch.cuda.get_device_capability() < (10, 0): + pytest.skip("Grouped GEMM requires Blackwell (SM100) or newer.") + if not is_bf16_available(): + pytest.skip("bfloat16 is required for grouped GEMM test.") + + torch.manual_seed(0) + z, m, k, n = (4, 512, 256, 256) + + split_points = torch.randperm(m - 1)[: z - 1] + 1 + split_points = torch.sort(split_points).values.tolist() + m_sizes = [split_points[0]] + m_sizes += [b - a for a, b in zip(split_points[:-1], split_points[1:])] + m_sizes.append(m - split_points[-1]) + assert sum(m_sizes) == m and len(m_sizes) == z + + dtype = torch.bfloat16 + + if layout == 
"TN": + A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight + B = [torch.randn(ms, k, dtype=dtype, device="cuda") for ms in m_sizes] # input + out = [torch.randn(ms, n, dtype=dtype, device="cuda") for ms in m_sizes] # output + grad = False + + elif layout == "NN": + A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight + B = [torch.randn(ms, n, dtype=dtype, device="cuda") for ms in m_sizes] # grad_output + out = [torch.randn(ms, k, dtype=dtype, device="cuda") for ms in m_sizes] # dgrad + grad = True + else: # layout == "NT" + A = [torch.randn(ms, k, dtype=dtype, device="cuda") for ms in m_sizes] # input + B = [torch.randn(ms, n, dtype=dtype, device="cuda") for ms in m_sizes] # grad_output + out = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # wgrad + grad = True + + out_ref = [o.clone() for o in out] + general_grouped_gemm( + A, + B, + out_ref, + [None] * z, + dtype, + m_splits=m_sizes, + grad=grad, + accumulate=accumulate, + layout=layout, + single_output=False, + ) + + device = A[0].device + + def _make_grouped_tensor_from_splits(m_sizes, last_dim): + first_dims = torch.tensor(m_sizes, device=device, dtype=torch.int64) + return GroupedTensor.make_grouped_tensor( + num_tensors=len(m_sizes), + first_dims=first_dims, + last_dims=None, + logical_first_dim=sum(m_sizes), + logical_last_dim=last_dim, + quantizer=None, + device=device, + dtype=dtype, + ) + + def _make_grouped_tensor_uniform(num_tensors, first_dim, last_dim): + return GroupedTensor.make_grouped_tensor( + num_tensors=num_tensors, + first_dims=None, + last_dims=None, + logical_first_dim=num_tensors * first_dim, + logical_last_dim=last_dim, + quantizer=None, + device=device, + dtype=dtype, + ) + + if layout == "TN": + grouped_A = _make_grouped_tensor_uniform(z, n, k) + grouped_B = _make_grouped_tensor_from_splits(m_sizes, k) + grouped_out = _make_grouped_tensor_from_splits(m_sizes, n) + elif layout == "NN": + grouped_A = _make_grouped_tensor_uniform(z, n, k) + grouped_B = _make_grouped_tensor_from_splits(m_sizes, n) + grouped_out = _make_grouped_tensor_from_splits(m_sizes, k) + else: # layout == "NT" + grouped_A = _make_grouped_tensor_from_splits(m_sizes, k) + grouped_B = _make_grouped_tensor_from_splits(m_sizes, n) + grouped_out = _make_grouped_tensor_uniform(z, n, k) + _pack_grouped_tensor(grouped_A, A) + _pack_grouped_tensor(grouped_B, B) + _pack_grouped_tensor(grouped_out, out) + + general_grouped_gemm_for_grouped_tensor( + grouped_A, + grouped_B, + grouped_out, + layout=layout, + accumulate=accumulate, + ) + + out_grouped = grouped_out.split_into_quantized_tensors() + tols = dtype_tols(dtype) + for o, o_ref in zip(out_grouped, out_ref): + torch.testing.assert_close(o, o_ref, **tols) + + @pytest.mark.parametrize("N", [32]) @pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize( diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index b3e216dc4f..f1333f3491 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -11,6 +11,7 @@ #include #include +#include #include "../common.h" #include "../util/cuda_runtime.h" @@ -138,7 +139,6 @@ struct GroupedGemmSetupWorkspace { offset += ptr_size; ws.beta_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; - // Int arrays for storage dimensions (4-byte aligned) ws.a_rows = reinterpret_cast(setup_ws_ptr + offset); 
offset += int_size; @@ -487,9 +487,9 @@ __global__ void setup_grouped_gemm_kernel( a_cols[idx] = static_cast(a_first); b_rows[idx] = static_cast(b_last); b_cols[idx] = static_cast(b_first); - // For OUTPUTS (D, C): cuBLAS writes in column-major, so rows=first (M), cols=last (N). - d_rows[idx] = static_cast(d_first); - d_cols[idx] = static_cast(d_last); + + d_rows[idx] = static_cast(d_last); + d_cols[idx] = static_cast(d_first); // Fill alpha/beta pointers (per-matrix) alpha_ptrs[idx] = alpha_ptr + idx; diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index ae41f238a4..1e56ecaa9e 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -957,6 +957,214 @@ class TensorWrapper { NVTETensor tensor_ = nullptr; }; +/*! \struct GroupedTensorWrapper + * \brief C++ wrapper for the NVTEGroupedTensor class. + */ + +class GroupedTensorWrapper { + public: + /*! \brief Constructs new GroupedTensorWrapper. + * + * Create a new TE grouped tensor with a given logical shape. + * TE grouped tensors are just wrappers on top of raw data and do not + * own memory. + * + * \param[in] num_tensors Number of tensors in the group (must be > 0). + * \param[in] logical_shape Logical 2D shape of the grouped data. + * \param[in] scaling_mode Tensor data format. + */ + GroupedTensorWrapper(const size_t num_tensors, const NVTEShape &logical_shape, + const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) + : tensor_(nvte_create_grouped_tensor(scaling_mode, num_tensors, logical_shape)) {} + + /*! \brief Constructs new GroupedTensorWrapper. + * + * Create a new TE grouped tensor with a given logical shape. + * + * \param[in] num_tensors Number of tensors in the group (must be > 0). + * \param[in] logical_shape Logical 2D shape of the grouped data. + * \param[in] scaling_mode Tensor data format. + */ + GroupedTensorWrapper(const size_t num_tensors, const std::vector &logical_shape, + const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) + : GroupedTensorWrapper(num_tensors, + nvte_make_shape(logical_shape.data(), logical_shape.size()), + scaling_mode) {} + + /*! \brief GroupedTensorWrapper destructor. */ + ~GroupedTensorWrapper() { nvte_destroy_grouped_tensor(tensor_); } + + GroupedTensorWrapper &operator=(const GroupedTensorWrapper &other) = delete; + GroupedTensorWrapper(const GroupedTensorWrapper &other) = delete; + + /*! \brief Constructs new GroupedTensorWrapper from existing GroupedTensorWrapper. */ + GroupedTensorWrapper(GroupedTensorWrapper &&other) { + tensor_ = other.tensor_; + other.tensor_ = nullptr; + } + + /*! \brief Assign the data from existing GroupedTensorWrapper. 
*/ + GroupedTensorWrapper &operator=(GroupedTensorWrapper &&other) { + if (this == &other) return *this; + nvte_destroy_grouped_tensor(tensor_); + tensor_ = other.tensor_; + other.tensor_ = nullptr; + return *this; + } + + // Parameter setters + template + GroupedTensorWrapper &set_parameter(const NVTEGroupedTensorParam param, void *dptr, DType type, + const ShapeType &shape) noexcept { + NVTEShape nvte_shape = this->convertShape(shape); + NVTEBasicTensor data = {dptr, static_cast(type), nvte_shape}; + nvte_set_grouped_tensor_param(&tensor_, param, &data); + return *this; + } + + template + GroupedTensorWrapper &set_rowwise_data(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedRowwiseData, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_columnwise_data(void *dptr, DType type, + const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedColumnwiseData, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_scale(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedScale, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_amax(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedAmax, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_rowwise_scale_inv(void *dptr, DType type, + const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedRowwiseScaleInv, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_columnwise_scale_inv(void *dptr, DType type, + const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedColumnwiseScaleInv, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_columnwise_amax(void *dptr, DType type, + const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedColumnwiseAmax, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_first_dims(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedFirstDims, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_last_dims(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedLastDims, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_tensor_offsets(void *dptr, DType type, + const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedTensorOffsets, dptr, type, shape); + } + + // Parameter getters + NVTEBasicTensor get_parameter(const NVTEGroupedTensorParam param) const noexcept { + return nvte_get_grouped_tensor_param(tensor_, param); + } + + NVTEBasicTensor get_rowwise_data() const noexcept { + return get_parameter(kNVTEGroupedRowwiseData); + } + + NVTEBasicTensor get_columnwise_data() const noexcept { + return get_parameter(kNVTEGroupedColumnwiseData); + } + + NVTEBasicTensor get_scale() const noexcept { return get_parameter(kNVTEGroupedScale); } + + NVTEBasicTensor get_amax() const noexcept { return get_parameter(kNVTEGroupedAmax); } + + NVTEBasicTensor get_rowwise_scale_inv() const noexcept { + return get_parameter(kNVTEGroupedRowwiseScaleInv); + } + + NVTEBasicTensor get_columnwise_scale_inv() const noexcept { + return get_parameter(kNVTEGroupedColumnwiseScaleInv); + } + + NVTEBasicTensor get_columnwise_amax() const noexcept { + return get_parameter(kNVTEGroupedColumnwiseAmax); + } + + NVTEBasicTensor get_first_dims() const noexcept { return get_parameter(kNVTEGroupedFirstDims); } + + NVTEBasicTensor get_last_dims() const noexcept { return 
get_parameter(kNVTEGroupedLastDims); } + + NVTEBasicTensor get_tensor_offsets() const noexcept { + return get_parameter(kNVTEGroupedTensorOffsets); + } + + /*! \brief Get an underlying NVTEGroupedTensor. + * + * \return NVTEGroupedTensor held by this GroupedTensorWrapper. + */ + NVTEGroupedTensor data() const noexcept { return tensor_; } + + /*! \brief Get the number of tensors in this GroupedTensorWrapper. */ + size_t num_tensors() const noexcept { + if (tensor_ == nullptr) return 0; + return nvte_grouped_tensor_num_tensors(tensor_); + } + + /*! \brief Get the data type of this GroupedTensorWrapper. */ + DType dtype() const noexcept { + if (tensor_ == nullptr) return DType::kNumTypes; + return static_cast(nvte_grouped_tensor_type(tensor_)); + } + + /*! \brief Get a scaling mode of the grouped tensor. */ + NVTEScalingMode scaling_mode() const noexcept { + if (tensor_ == nullptr) return NVTE_DELAYED_TENSOR_SCALING; + return nvte_grouped_tensor_scaling_mode(tensor_); + } + + /*! \brief Get the logical shape of this GroupedTensorWrapper. */ + const NVTEShape logical_shape() const noexcept { + if (tensor_ == nullptr) { + return emptyShape; + } + return nvte_get_grouped_tensor_logical_shape(tensor_); + } + + static constexpr size_t defaultData = 1; + static constexpr NVTEShape defaultShape = { + {defaultData, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1}; + static constexpr NVTEShape emptyShape = {{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1}; + + private: + NVTEShape convertShape(const NVTEShape &s) { return s; } + + NVTEShape convertShape(const std::vector &s) { + return nvte_make_shape(s.data(), s.size()); + } + + /*! \brief Wrapped NVTEGroupedTensor. */ + NVTEGroupedTensor tensor_ = nullptr; +}; + +/*! \enum Float8BlockScaleTensorFormat + * \brief Data format for an FP8 block-scaled tensor + */ /*! \warning Deprecated */ enum class Float8BlockScaleTensorFormat { GEMM_READY = 0, COMPACT = 1, INVALID }; diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 406e7075f7..5ef0ef741e 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -5,6 +5,7 @@ """Python interface for GEMM extensions""" from typing import Iterable, Optional, Tuple, Union, List +import ctypes import os import functools import torch @@ -14,6 +15,7 @@ from ..quantized_tensor import Quantizer from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage +from ..tensor.storage.grouped_tensor import GroupedTensor from ..tensor.utils import is_custom from ..custom_recipes.gemm import custom_gemm from ...debug.pytorch.debug_quantization import DebugQuantizer @@ -22,6 +24,7 @@ __all__ = [ "general_gemm", "general_grouped_gemm", + "general_grouped_gemm_for_grouped_tensor", ] @@ -306,3 +309,94 @@ def general_grouped_gemm( ) return out, bias, gelu_input + + +def get_grouped_gemm_setup_workspace_size(num_tensors: int) -> int: + """Return workspace size for grouped GEMM pointer setup. + Must match GroupedGemmSetupWorkspace::required_setup_size in cublaslt_grouped_gemm.cu. 
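    For example (illustrative only, assuming 8-byte pointers and 4-byte ints on a
    64-bit build): num_tensors=3 gives ptr_size=24 (aligned up to 32) and int_size=12,
    so size = 8*32 + 6*12 = 328, which the final 256-byte alignment rounds up to 512.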
+ """ + ptr_bytes = ctypes.sizeof(ctypes.c_void_p) + int_bytes = ctypes.sizeof(ctypes.c_int) + ptr_size = num_tensors * ptr_bytes + int_size = num_tensors * int_bytes + k_ptr_alignment = 16 + aligned_ptr_size = ((ptr_size + k_ptr_alignment - 1) // k_ptr_alignment) * k_ptr_alignment + size = 8 * aligned_ptr_size + 6 * int_size + alignment = 256 + return ((size + alignment - 1) // alignment) * alignment + + +def general_grouped_gemm_for_grouped_tensor( + A, + B, + out, + *, + layout: str = "TN", + accumulate: bool = False, + alpha: Optional[torch.Tensor] = None, + beta: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + Grouped GEMM using GroupedTensor inputs. + + This uses nvte_grouped_gemm and supports different per-matrix shapes. + + The caller must ensure that GroupedTensor metadata is already compatible with the + underlying GEMM implementation (e.g., aligned offsets and output metadata layout). + """ + assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported." + transa = layout[0] == "T" + transb = layout[1] == "T" + + num_tensors = A.num_tensors + assert A.num_tensors == B.num_tensors == out.num_tensors, ( + f"GroupedTensor num_tensors must match: A={A.num_tensors}, B={B.num_tensors}," + f" out={out.num_tensors}" + ) + + if out.data is not None: + device = out.data.device + elif out.columnwise_data is not None: + device = out.columnwise_data.device + else: + raise ValueError("Output GroupedTensor must have allocated data.") + + if alpha is None: + alpha = torch.ones(num_tensors, dtype=torch.float32, device=device) + if beta is None: + if accumulate: + beta = torch.ones(num_tensors, dtype=torch.float32, device=device) + else: + beta = torch.zeros(num_tensors, dtype=torch.float32, device=device) + + if not alpha.is_cuda or not beta.is_cuda: + raise ValueError("alpha and beta must be CUDA tensors.") + + workspace_setup = torch.empty( + get_grouped_gemm_setup_workspace_size(num_tensors), + dtype=torch.uint8, + device=device, + ) + workspace_cublas = torch.empty( + get_cublas_workspace_size_bytes(), + dtype=torch.uint8, + device=device, + ) + + sm_count = get_sm_count() + sm_count = sm_count - int(os.getenv("NVTE_EXT_MARGIN_SM", str(sm_count))) + + C = out + return tex.te_general_grouped_gemm_for_grouped_tensor( + A, + transa, + B, + transb, + C, + out, + alpha, + beta, + workspace_setup, + workspace_cublas, + sm_count, + ) diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index e0ea3d6b78..e901919831 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -150,6 +150,13 @@ std::optional> te_general_grouped_gemm( std::vector pre_gelu_out, bool grad, std::vector workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count); +py::object te_general_grouped_gemm_for_grouped_tensor(py::handle A, bool transa, py::handle B, + bool transb, py::object C, py::handle D, + at::Tensor alpha, at::Tensor beta, + at::Tensor workspace_setup, + at::Tensor workspace_cublas, + int math_sm_count); + /*************************************************************************************************** * Transpose **************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index d75b0f14c7..d5a8ff5489 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ 
b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -570,4 +570,72 @@ std::optional> te_general_grouped_gemm( return bias; } +py::object te_general_grouped_gemm_for_grouped_tensor(py::handle A, bool transa, py::handle B, + bool transb, py::object C, py::handle D, + at::Tensor alpha, at::Tensor beta, + at::Tensor workspace_setup, + at::Tensor workspace_cublas, + int math_sm_count) { + using namespace transformer_engine::pytorch::detail; + + init_extension(); + + // Ensure that cublasLt handle is created on the correct device, + // overriding torch.cuda.set_device calls from user side. + // Assumes all tensors passed are on the same device. + at::cuda::CUDAGuard device_guard(workspace_cublas.device()); + + auto grouped_A = GroupedTensorFromPyTorchGroupedTensor(A); + auto grouped_B = GroupedTensorFromPyTorchGroupedTensor(B); + auto grouped_D = GroupedTensorFromPyTorchGroupedTensor(D); + + std::optional grouped_C = std::nullopt; + if (!C.is_none()) { + grouped_C = GroupedTensorFromPyTorchGroupedTensor(C); + } + + const size_t num_tensors = grouped_A.num_tensors(); + NVTE_CHECK(num_tensors > 0, "Grouped GEMM requires non-empty inputs."); + NVTE_CHECK(grouped_B.num_tensors() == num_tensors, + "Grouped GEMM requires A and B to have the same num_tensors."); + NVTE_CHECK(grouped_D.num_tensors() == num_tensors, + "Grouped GEMM requires D to have the same num_tensors as inputs."); + if (grouped_C.has_value()) { + NVTE_CHECK(grouped_C->num_tensors() == num_tensors, + "Grouped GEMM requires C to have the same num_tensors as inputs."); + } + + NVTE_CHECK(alpha.numel() == static_cast(num_tensors), + "Grouped GEMM expects alpha to have num_tensors elements."); + NVTE_CHECK(beta.numel() == static_cast(num_tensors), + "Grouped GEMM expects beta to have num_tensors elements."); + + auto te_alpha = makeTransformerEngineTensor(alpha); + auto te_beta = makeTransformerEngineTensor(beta); + + auto te_workspace_setup = makeTransformerEngineTensor( + workspace_setup.data_ptr(), std::vector{static_cast(workspace_setup.numel())}, + DType::kByte); + auto te_workspace_cublas = makeTransformerEngineTensor( + workspace_cublas.data_ptr(), + std::vector{static_cast(workspace_cublas.numel())}, DType::kByte); + + std::optional config; + if (math_sm_count > 0) { + config.emplace(); + config->set_sm_count(math_sm_count); + } + + NVTE_SCOPED_GIL_RELEASE({ + nvte_grouped_gemm(grouped_A.data(), transa, grouped_B.data(), transb, + grouped_C.has_value() ? grouped_C->data() : nullptr, grouped_D.data(), + te_alpha.data(), te_beta.data(), te_workspace_setup.data(), + te_workspace_cublas.data(), + config.has_value() ? 
static_cast(*config) : nullptr, + at::cuda::getCurrentCUDAStream()); + }); + + return py::reinterpret_borrow(D); +} + } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 1e907d9bc0..83ae45299e 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -251,6 +251,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("quantizer_list"), py::arg("disable_bulk_allocation") = false); m.def("te_general_grouped_gemm", &transformer_engine::pytorch::te_general_grouped_gemm, "Grouped GEMM"); + m.def("te_general_grouped_gemm_for_grouped_tensor", + &transformer_engine::pytorch::te_general_grouped_gemm_for_grouped_tensor, + "Grouped GEMM for GroupedTensor"); m.def("fp8_transpose", &transformer_engine::pytorch::fp8_transpose, "Transpose with FP8 I/O", py::arg("input"), py::arg("dtype"), py::kw_only(), py::arg("out"), py::call_guard()); diff --git a/transformer_engine/pytorch/csrc/pybind.h b/transformer_engine/pytorch/csrc/pybind.h index 25ffef0588..9541409c0c 100644 --- a/transformer_engine/pytorch/csrc/pybind.h +++ b/transformer_engine/pytorch/csrc/pybind.h @@ -95,6 +95,8 @@ TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer); +GroupedTensorWrapper GroupedTensorFromPyTorchGroupedTensor(py::handle tensor); + inline bool IsFloatingPointType(at::ScalarType type) { return type == at::kFloat || type == at::kHalf || type == at::kBFloat16; } diff --git a/transformer_engine/pytorch/csrc/type_converters.cpp b/transformer_engine/pytorch/csrc/type_converters.cpp index 3f998bb66f..07961d85d4 100644 --- a/transformer_engine/pytorch/csrc/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/type_converters.cpp @@ -170,6 +170,119 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer) return ret; } +NVTEScalingMode ScalingModeFromQuantizer(py::handle quantizer) { + auto *quantizer_ptr = quantizer.ptr(); + if (IsMXFP8Quantizers(quantizer_ptr)) { + return NVTE_MXFP8_1D_SCALING; + } + if (IsNVFP4Quantizers(quantizer_ptr)) { + return NVTE_NVFP4_1D_SCALING; + } + if (IsFloat8BlockwiseQuantizers(quantizer_ptr)) { + const int block_scaling_dim = quantizer.attr("block_scaling_dim").cast(); + return (block_scaling_dim == 2) ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D; + } + return NVTE_DELAYED_TENSOR_SCALING; +} + +DType GetTransformerEngineDTypeForScaleInv(py::handle quantizer, at::Tensor scale_inv) { + auto *quantizer_ptr = quantizer.ptr(); + if (IsMXFP8Quantizers(quantizer_ptr)) { + return DType::kFloat8E8M0; + } + if (IsFloat8BlockwiseQuantizers(quantizer_ptr)) { + return DType::kFloat32; + } + if (IsNVFP4Quantizers(quantizer_ptr)) { + return DType::kFloat8E4M3; + } + return GetTransformerEngineDType(scale_inv.scalar_type()); +} + +GroupedTensorWrapper GroupedTensorFromPyTorchGroupedTensor(py::handle tensor) { + // Returns a GroupedTensorWrapper from a PyTorch GroupedTensor. 
+ const auto num_tensors = tensor.attr("num_tensors").cast(); + const auto logical_shape = tensor.attr("logical_shape").cast>(); + py::handle quantizer = py::none(); + DType quantizer_dtype = DType::kNumTypes; + NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING; + if (!tensor.attr("quantizer").is_none()) { + quantizer = tensor.attr("quantizer").cast(); + scaling_mode = ScalingModeFromQuantizer(quantizer); + quantizer_dtype = quantizer.attr("dtype").cast(); + } + auto ret = GroupedTensorWrapper(num_tensors, logical_shape, scaling_mode); + + // Rowwise data + if (!tensor.attr("data").is_none()) { + const auto &data = tensor.attr("data").cast(); + DType data_dtype = + quantizer.is_none() ? GetTransformerEngineDType(data.scalar_type()) : quantizer_dtype; + ret.set_rowwise_data(data.data_ptr(), data_dtype, getTensorShape(data)); + } + + // Columnwise data + if (!tensor.attr("columnwise_data").is_none()) { + const auto &data = tensor.attr("columnwise_data").cast(); + DType data_dtype = + quantizer.is_none() ? GetTransformerEngineDType(data.scalar_type()) : quantizer_dtype; + ret.set_columnwise_data(data.data_ptr(), data_dtype, getTensorShape(data)); + } + + // Scale + if (!tensor.attr("scale").is_none()) { + const auto &scale = tensor.attr("scale").cast(); + ret.set_scale(scale.data_ptr(), GetTransformerEngineDType(scale.scalar_type()), + getTensorShape(scale)); + } + + // Amax + if (!tensor.attr("amax").is_none()) { + const auto &amax = tensor.attr("amax").cast(); + ret.set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), + getTensorShape(amax)); + } + if (!tensor.attr("columnwise_amax").is_none()) { + const auto &amax = tensor.attr("columnwise_amax").cast(); + ret.set_columnwise_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), + getTensorShape(amax)); + } + + // Scale inverse + if (!tensor.attr("scale_inv").is_none()) { + const auto &scale_inv = tensor.attr("scale_inv").cast(); + ret.set_rowwise_scale_inv(scale_inv.data_ptr(), + GetTransformerEngineDTypeForScaleInv(quantizer, scale_inv), + getTensorShape(scale_inv)); + } + if (!tensor.attr("columnwise_scale_inv").is_none()) { + const auto &scale_inv = tensor.attr("columnwise_scale_inv").cast(); + ret.set_columnwise_scale_inv(scale_inv.data_ptr(), + GetTransformerEngineDTypeForScaleInv(quantizer, scale_inv), + getTensorShape(scale_inv)); + } + + // Shape metadata + if (!tensor.attr("first_dims").is_none()) { + const auto &first_dims = tensor.attr("first_dims").cast(); + ret.set_first_dims(first_dims.data_ptr(), GetTransformerEngineDType(first_dims.scalar_type()), + getTensorShape(first_dims)); + } + if (!tensor.attr("last_dims").is_none()) { + const auto &last_dims = tensor.attr("last_dims").cast(); + ret.set_last_dims(last_dims.data_ptr(), GetTransformerEngineDType(last_dims.scalar_type()), + getTensorShape(last_dims)); + } + if (!tensor.attr("tensor_offsets").is_none()) { + const auto &tensor_offsets = tensor.attr("tensor_offsets").cast(); + ret.set_tensor_offsets(tensor_offsets.data_ptr(), + GetTransformerEngineDType(tensor_offsets.scalar_type()), + getTensorShape(tensor_offsets)); + } + + return ret; +} + } // namespace detail } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index b6596bc2e9..bdf8dc6388 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -3,8 +3,12 @@ # See LICENSE for license information. 
"""GroupedLinear API""" -from typing import Union, Optional, Callable, Tuple, List +from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage +from torch._tensor import Tensor +from typing import Any, Union, Optional, Callable, Tuple, List from itertools import chain +from torch.distributed.tensor import DTensor + import warnings import functools @@ -39,6 +43,7 @@ ) from ..cpp_extensions import ( general_grouped_gemm, + general_grouped_gemm_for_grouped_tensor, ) from ..constants import GemmParallelModes, dist_group_type from ..jit import no_torch_dynamo @@ -57,6 +62,48 @@ __all__ = ["GroupedLinear"] +def _clone_grouped_tensor_with_data( + grouped_tensor: GroupedTensor, data: torch.Tensor, dtype: torch.dtype +) -> GroupedTensor: + return GroupedTensor( + num_tensors=grouped_tensor.num_tensors, + shape=grouped_tensor.shape, + quantizer=grouped_tensor.quantizer, + dtype=dtype, + data=data, + columnwise_data=grouped_tensor.columnwise_data, + scale_inv=grouped_tensor.scale_inv, + columnwise_scale_inv=grouped_tensor.columnwise_scale_inv, + amax=grouped_tensor.amax, + columnwise_amax=grouped_tensor.columnwise_amax, + scale=grouped_tensor.scale, + first_dims=grouped_tensor.first_dims, + last_dims=grouped_tensor.last_dims, + tensor_offsets=grouped_tensor.tensor_offsets, + offsets=grouped_tensor.offsets, + scale_inv_offsets=grouped_tensor.scale_inv_offsets, + columnwise_scale_inv_offsets=grouped_tensor.columnwise_scale_inv_offsets, + logical_shape=grouped_tensor.logical_shape, + ) + + +def _make_grouped_tensor_for_m_splits(data: torch.Tensor, m_splits: torch.Tensor) -> GroupedTensor: + # Use data.shape[0] to avoid first_dims.sum().item() D2H copy (breaks CUDA graph) + logical_first_dim = data.shape[0] + grouped = GroupedTensor.make_grouped_tensor( + num_tensors=int(m_splits.numel()), + first_dims=m_splits, + last_dims=None, + logical_first_dim=logical_first_dim, + logical_last_dim=data.shape[-1], + quantizer=None, + device=data.device, + dtype=data.dtype, + ) + grouped.data = data.contiguous().view(-1) + return grouped + + class _GroupedLinear(torch.autograd.Function): """GroupedLinear semi-top level module Calls custom cuda extensions. @@ -76,6 +123,7 @@ def forward( # to reduce CPU overhead due to pytorch arg checking. 
( m_splits, + m_splits_is_tensor, use_bias, is_first_microbatch, fp8, @@ -97,10 +145,11 @@ def forward( save_original_input, debug, ) = non_tensor_args - - num_gemms = len(m_splits) - weights = weights_and_biases[:num_gemms] - biases = weights_and_biases[num_gemms:] + num_weight_params = module.num_weight_params + num_gemms = int(m_splits.numel()) if m_splits_is_tensor else len(m_splits) + logical_first_dim = inp.shape[0] if m_splits_is_tensor else sum(m_splits) + weights = weights_and_biases[:num_weight_params] + biases = weights_and_biases[num_weight_params:] device = inp.device weight_requires_grad = weights[0].requires_grad @@ -133,9 +182,11 @@ def forward( if output_quantizers[0] is not None: for output_quantizer in output_quantizers: output_quantizer.set_usage(rowwise=True, columnwise=False) + no_quantization = not fp8 and weight_quantizers[0] is None # Initialize input tensors - in_features = weights[0].size(-1) + in_features = module.in_features + out_features = module.out_features if inp.size(-1) != in_features: raise ValueError( f"Input tensor (shape={tuple(inp.size())}) is not compatible with " @@ -143,6 +194,14 @@ def forward( ) inp_view = inp.reshape(-1, in_features) inputmats: list + inp_view_cast = None + if m_splits_is_tensor and not no_quantization: + # TODO: Support this path. + raise ValueError( + "GroupedGEMM with grouped tensor path with quantization is not supported yet." + ) + grouped_tensor_path = no_quantization and m_splits_is_tensor + if fp8 and not debug: # Disable bulk allocation when CPU offloading is active: offloading skips small # tensors (like scales), but bulk allocation shares storage across all tensors, @@ -158,7 +217,10 @@ def forward( inp_view, input_quantizers, m_splits, activation_dtype ) else: - inputmats = torch.split(cast_if_needed(inp_view, activation_dtype), m_splits) + inp_view_cast = cast_if_needed(inp_view, activation_dtype) + inputmats = ( + [inp_view_cast] if grouped_tensor_path else torch.split(inp_view_cast, m_splits) + ) if cpu_offloading: start_offload(*inputmats) @@ -169,7 +231,7 @@ def forward( # FP8 cast to workspace buffer weights_fp8 = [] update_workspace = is_first_microbatch is None or is_first_microbatch - for i in range(num_gemms): + for i in range(num_weight_params): weight_fp8 = module.get_weight_workspace( tensor=weights[i], quantizer=weight_quantizers[i], @@ -190,7 +252,7 @@ def forward( biases = [cast_if_needed(bias, bias_dtype) for bias in biases] if use_bias else biases # Initialize output tensor out = torch.empty( - [sum(m_splits), weights_fp8[0].size(0)], + [logical_first_dim, out_features], dtype=activation_dtype, device=device, ) @@ -202,19 +264,35 @@ def forward( if hasattr(recipe, "fp8_gemm_fprop"): use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator - # Perform GEMM - general_grouped_gemm( - weights_fp8, - inputmats, - [out], - output_quantizers, - activation_dtype, - single_output=True, - m_splits=m_splits, - bias=biases, - use_bias=use_bias, - use_split_accumulator=use_split_accumulator, - ) + if grouped_tensor_path: + grouped_weight = _clone_grouped_tensor_with_data( + module.grouped_weight_storage, + cast_if_needed(module.grouped_weight_storage.data, activation_dtype), + activation_dtype, + ) + grouped_input = _make_grouped_tensor_for_m_splits(inputmats[0], m_splits) + grouped_out = _make_grouped_tensor_for_m_splits(out, m_splits) + general_grouped_gemm_for_grouped_tensor( + grouped_weight, + grouped_input, + grouped_out, + layout="TN", + accumulate=False, + ) + else: + # Perform GEMM + 
general_grouped_gemm( + weights_fp8, + inputmats, + [out], + output_quantizers, + activation_dtype, + single_output=True, + m_splits=m_splits, + bias=biases, + use_bias=use_bias, + use_split_accumulator=use_split_accumulator, + ) if fp8_calibration: for i in range(num_gemms): @@ -229,7 +307,10 @@ def forward( if is_grad_enabled: ctx.weight_quantizers = weight_quantizers - ctx.weights_shape_1 = weights[0].shape[1] + if module.single_weight: + ctx.weights_shape_1 = module.in_features + else: + ctx.weights_shape_1 = weights[0].shape[1] # TODO: update after #1638 is merged. # pylint: disable=fixme if weight_requires_grad: @@ -264,7 +345,6 @@ def forward( ) ctx.save_for_backward(*tensors_to_save) ctx.tensor_objects = tensor_objects - ctx.grad_input_quantizers = grad_input_quantizers ctx.grad_output_quantizers = grad_output_quantizers ctx.grad_weight_quantizers = grad_weight_quantizers @@ -276,17 +356,22 @@ def forward( # the main_grad buffer lazily before backprop if hasattr(weights[0], "__fsdp_param__"): # MCore FSDP creates main_grad lazily before backward - ctx.main_grad_funcs = [weights[i].get_main_grad for i in range(num_gemms)] + ctx.main_grad_funcs = [ + weights[i].get_main_grad for i in range(num_weight_params) + ] else: ctx.main_grad_funcs = [ - lambda j=i: weights[j].main_grad for i in range(num_gemms) + lambda j=i: weights[j].main_grad for i in range(num_weight_params) ] else: - ctx.main_grad_funcs = [lambda: None for i in range(num_gemms)] + ctx.main_grad_funcs = [lambda: None for i in range(num_weight_params)] ctx.device = device ctx.output_quantizers = output_quantizers ctx.m_splits = m_splits + ctx.logical_first_dim = logical_first_dim + ctx.grouped_tensor_path = grouped_tensor_path ctx.num_gemms = num_gemms + ctx.num_weight_params = num_weight_params ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None @@ -295,6 +380,8 @@ def forward( ctx.is_first_microbatch = is_first_microbatch ctx.use_bias = use_bias ctx.sequence_parallel = sequence_parallel + ctx.in_features = module.in_features + ctx.out_features = module.out_features ctx.inp_shape = inp.shape ctx.requires_dgrad = inp.requires_grad ctx.reduce_and_update_bwd_fp8_tensors = False @@ -307,7 +394,10 @@ def forward( ctx.debug = debug ctx.save_original_input = save_original_input ctx.input_quantizers = input_quantizers - + ctx.single_weight = module.single_weight + ctx.grouped_weight_storage = ( + module.grouped_weight_storage if grouped_tensor_path else None + ) # [*, in_features] -> [*, out_features] except first dimension changes for SP return out.view(-1, *inp.shape[1:-1], out.shape[-1]) @@ -315,8 +405,9 @@ def forward( def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: # pylint: disable=missing-function-docstring with get_nvtx_range_context("_GroupedLinear_backward"): + m_splits = ctx.m_splits saved_tensors = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) - N = ctx.num_gemms + N = ctx.num_weight_params inputmats = saved_tensors[:N] weights = saved_tensors[N : 2 * N] origin_weights = saved_tensors[2 * N : 3 * N] @@ -366,7 +457,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ) elif ctx.debug: grad_output_mats = torch.split(grad_output_view, ctx.m_splits) - for i in range(ctx.num_gemms): + for i in range(ctx.num_weight_params): grad_biases[i] = grad_output_mats[i].sum(dim=0) grad_output = DebugQuantizer.multi_tensor_quantize( grad_output_view, @@ -377,10 +468,15 @@ def 
backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], else: # Only split grad output. Grad bias is fused with # wgrad GEMM. - grad_output = torch.split( - cast_if_needed(grad_output_view, ctx.activation_dtype), - ctx.m_splits, - ) + if ctx.grouped_tensor_path: + out = cast_if_needed(grad_output_view, ctx.activation_dtype) + grad_output = [out] + grouped_grad_output = _make_grouped_tensor_for_m_splits(out, m_splits) + else: + grad_output = torch.split( + cast_if_needed(grad_output_view, ctx.activation_dtype), + ctx.m_splits, + ) if ctx.is_first_microbatch is not None: accumulate_wgrad_into_param_main_grad = ( @@ -398,27 +494,39 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], recipe.fp8_gemm_dgrad.use_split_accumulator ) dgrad = torch.empty( - (sum(ctx.m_splits), ctx.weights_shape_1), + ctx.inp_shape, dtype=ctx.activation_dtype, device=ctx.device, ) + # Make sure weights are available in column-wise format # for dgrad computation. for weight in weights: if isinstance(weight, QuantizedTensorStorage): weight.update_usage(columnwise_usage=True) - general_grouped_gemm( - weights, - grad_output, - [dgrad], - ctx.grad_input_quantizers, - ctx.activation_dtype, - single_output=True, - layout="NN", - m_splits=ctx.m_splits, - grad=True, - use_split_accumulator=dgrad_gemm_use_split_accumulator, - ) + if ctx.grouped_tensor_path: + grouped_weight = ctx.grouped_weight_storage + grouped_dgrad = _make_grouped_tensor_for_m_splits(dgrad, m_splits) + general_grouped_gemm_for_grouped_tensor( + grouped_weight, + grouped_grad_output, + grouped_dgrad, + layout="NN", + accumulate=False, + ) + else: + general_grouped_gemm( + weights, + grad_output, + [dgrad], + ctx.grad_input_quantizers, + ctx.activation_dtype, + single_output=True, + layout="NN", + m_splits=ctx.m_splits, + grad=True, + use_split_accumulator=dgrad_gemm_use_split_accumulator, + ) if ctx.weights_requires_grad: wgrad_gemm_use_split_accumulator = _2X_ACC_WGRAD @@ -428,7 +536,26 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], wgrad_gemm_use_split_accumulator = ( recipe.fp8_gemm_wgrad.use_split_accumulator ) - if ctx.fuse_wgrad_accumulation: + grouped_wgrad = None + if ctx.grouped_tensor_path and ctx.fuse_wgrad_accumulation: + raise NotImplementedError( + "Fused wgrad accumulation is not supported with grouped tensor path." + ) + if ctx.grouped_tensor_path: + # Wgrad GEMM writes one output per group; use num_gemms (not num_weight_params). 
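                    # e.g. single_weight=True keeps a single flat weight0 parameter
                    # (num_weight_params == 1) while num_gemms separate
                    # (out_features, in_features) wgrad blocks are still produced below.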
+ num_wgrad_tensors = ctx.num_gemms + grouped_wgrad = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_wgrad_tensors, + shape=[(ctx.out_features, ctx.in_features)] * num_wgrad_tensors, + quantizer=None, + dtype=ctx.activation_dtype, + device=ctx.device, + ) + if ctx.single_weight: + wgrad_list = [grouped_wgrad.data.view(-1)] + else: + wgrad_list = grouped_wgrad.split_into_quantized_tensors() + elif ctx.fuse_wgrad_accumulation: wgrad_list = main_grads else: wgrad_list = [ @@ -460,32 +587,66 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ctx.activation_dtype, ) else: - inputmats = torch.split( - cast_if_needed(inp_view, ctx.activation_dtype), ctx.m_splits + if ctx.grouped_tensor_path: + inputmats = [cast_if_needed(inp_view, ctx.activation_dtype)] + else: + inputmats = torch.split( + cast_if_needed(inp_view, ctx.activation_dtype), ctx.m_splits + ) + + if ctx.grouped_tensor_path: + + def grouped_gemm_wgrad_grouped_tensor(inputmat, grad_output, grouped_wgrad): + grouped_input = _make_grouped_tensor_for_m_splits(inputmat, ctx.m_splits) + grouped_grad_output = _make_grouped_tensor_for_m_splits( + grad_output, ctx.m_splits ) - grouped_gemm_wgrad = functools.partial( - general_grouped_gemm, - quantization_params=ctx.grad_weight_quantizers, - out_dtype=ctx.activation_dtype, - layout="NT", - grad=True, - m_splits=ctx.m_splits, - use_bias=ctx.use_bias if grad_biases[0] is None else None, - bias=biases, - use_split_accumulator=wgrad_gemm_use_split_accumulator, - accumulate=( - accumulate_wgrad_into_param_main_grad - if not getattr(weights[0], "overwrite_main_grad", False) - else False - ), - ) + # dW = grad_output^T @ input -> (out_features, m) @ (m, in_features). + # Row-wise: A (m, n) -> cuBLAS (n, m); use A=grad_output, B=input. + # Layout NT: op(A)=(n, m), op(B)^T=(m, k) -> D = (n, k). 
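                    # Illustrative sizes: with m_splits=[5, 7] and (k, n) = (in_features,
                    # out_features), group i contributes dW_i = grad_output_i^T @ input_i
                    # of shape (n, k), written back-to-back into grouped_wgrad.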
+ general_grouped_gemm_for_grouped_tensor( + grouped_grad_output, + grouped_input, + grouped_wgrad, + layout="NT", + accumulate=( + accumulate_wgrad_into_param_main_grad + if not getattr(weights[0], "overwrite_main_grad", False) + else False + ), + ) + return None, [None] * ctx.num_weight_params, None + + grouped_gemm_wgrad = grouped_gemm_wgrad_grouped_tensor + else: + grouped_gemm_wgrad = functools.partial( + general_grouped_gemm, + quantization_params=ctx.grad_weight_quantizers, + out_dtype=ctx.activation_dtype, + layout="NT", + grad=True, + m_splits=ctx.m_splits, + use_bias=ctx.use_bias if grad_biases[0] is None else None, + bias=biases, + use_split_accumulator=wgrad_gemm_use_split_accumulator, + accumulate=( + accumulate_wgrad_into_param_main_grad + if not getattr(weights[0], "overwrite_main_grad", False) + else False + ), + ) # WGRAD if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute(): ctx.wgrad_store.put([inputmats, grad_output, wgrad_list], grouped_gemm_wgrad) + elif ctx.grouped_tensor_path: + # Pass 2D view so _make_grouped_tensor_for_m_splits gets correct logical_last_dim + grad_output_2d = grad_output[0].view(ctx.logical_first_dim, ctx.out_features) + # wgrad_list shares the same memory with grouped_wgrad + grouped_gemm_wgrad(inputmats[0], grad_output_2d, grouped_wgrad) else: _, grad_biases_, _ = grouped_gemm_wgrad(inputmats, grad_output, wgrad_list) - for i in range(ctx.num_gemms): + for i in range(ctx.num_weight_params): if grad_biases[i] is None: grad_biases[i] = grad_biases_[i] del grad_biases_ @@ -522,14 +683,14 @@ def handle_custom_ddp_from_mcore(weight, wgrad): for weight, wgrad in zip(origin_weights, wgrad_list) ] else: - wgrad_list = [None] * ctx.num_gemms + wgrad_list = [None] * (ctx.num_weight_params) if not ctx.use_bias or ( ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute() and not ctx.fp8 ): - grad_biases = [None] * ctx.num_gemms + grad_biases = [None] * (ctx.num_weight_params) if ctx.reduce_and_update_bwd_fp8_tensors: FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) @@ -624,6 +785,7 @@ def __init__( delay_wgrad_compute: bool = False, save_original_input: bool = False, name: Optional[str] = None, + single_weight: bool = False, ) -> None: super().__init__(name) @@ -632,6 +794,9 @@ def __init__( self.in_features = in_features self.out_features = out_features self.fuse_wgrad_accumulation = fuse_wgrad_accumulation + if single_weight: + bias = False + return_bias = False self.use_bias = bias self.return_bias = return_bias self.apply_bias = bias and not return_bias @@ -639,6 +804,7 @@ def __init__( self.ub_overlap_ag = ub_overlap_ag self.ub_name = ub_name self.save_original_input = save_original_input + self.single_weight = single_weight assert ( not ub_overlap_rs and not ub_overlap_ag ), "GroupedLinear doesn't support Userbuffer overlap." 
@@ -687,14 +853,30 @@ def __init__( self.sequence_parallel = (self.tp_size > 1) and sequence_parallel - for i in range(self.num_gemms): - # Construct weight parameter + if self.single_weight and self.primary_weights_in_fp8: + raise ValueError("Single weight is only supported for High precision weights.") + + if self.single_weight: + shape_weight = [(self.out_features * self.num_gemms * self.in_features,)] + shape_bias = [(self.out_features * self.num_gemms,)] + param_names = ["weight0", "bias0"] + self.num_weight_params = 1 + num_tensors = 1 + else: + shape_weight = [(self.out_features, self.in_features) for _ in range(self.num_gemms)] + shape_bias = [self.out_features for _ in range(self.num_gemms)] + num_tensors = self.num_gemms + param_names = [f"weight{i}" for i in range(self.num_gemms)] + [ + f"bias{i}" for i in range(self.num_gemms) + ] + self.num_weight_params = self.num_gemms + + for i in range(num_tensors): self.register_parameter( f"weight{i}", torch.nn.Parameter( torch.empty( - self.out_features, - self.in_features, + shape_weight[i], device=device, dtype=self.params_dtype, ), @@ -710,7 +892,7 @@ def __init__( f"bias{i}", torch.nn.Parameter( torch.empty( - self.out_features, + shape_bias[i], device=device, dtype=self.params_dtype, ), @@ -729,9 +911,8 @@ def __init__( if self.wgrad_store.delay_wgrad_compute(): for name, param in self.named_parameters(): - for i in range(self.num_gemms): - if name in (f"weight{i}", f"bias{i}"): - param.skip_backward_post_hook = True + if name in param_names: + param.skip_backward_post_hook = True def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" @@ -750,6 +931,17 @@ def make_grouped_weights(self, defer_init=False) -> None: if defer_init: return + if self.single_weight: + weight = getattr(self, "weight0") + logical_shape = (self.num_gemms * self.out_features, self.in_features) + self.grouped_weight_storage = GroupedTensor( + num_tensors=self.num_gemms, + shape=[(self.out_features, self.in_features) for _ in range(self.num_gemms)], + quantizer=None, + dtype=self.params_dtype, + data=weight, + logical_shape=logical_shape, + ) weight_quantizers = self._get_weight_quantizers() recipe = ( weight_quantizers[0]._get_compatible_recipe() @@ -770,7 +962,7 @@ def make_grouped_weights(self, defer_init=False) -> None: dtype=self.params_dtype, device=weights[0].device, ) - + self.grouped_weight_storage = grouped_weights # Copy existing params into storage. 
with torch.no_grad(): for i in range(self.num_gemms): @@ -800,7 +992,7 @@ def set_tensor_parallel_attributes(self, defer_init=False) -> None: if not defer_init: # Set parallelism attributes for linear weights - for i in range(self.num_gemms): + for i in range(self.num_weight_params): set_tensor_model_parallel_attributes( tensor=getattr(self, f"weight{i}"), is_parallel=True, @@ -810,12 +1002,10 @@ def set_tensor_parallel_attributes(self, defer_init=False) -> None: # Set parallelism attributes for linear biases if self.use_bias: - for i in range(self.num_gemms): + for i in range(self.num_weight_params): if self.parallel_mode == "row": setattr( - getattr(self, f"bias{i}"), - "sequence_parallel", - self.sequence_parallel, + getattr(self, f"bias{i}"), "sequence_parallel", self.sequence_parallel ) elif self.parallel_mode == "column": set_tensor_model_parallel_attributes(getattr(self, f"bias{i}"), True, 0, 1) @@ -824,7 +1014,7 @@ def set_tensor_parallel_attributes(self, defer_init=False) -> None: def forward( self, inp: torch.Tensor, - m_splits: List[int], + m_splits: Union[List[int], torch.Tensor], is_first_microbatch: Optional[bool] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: """ @@ -834,8 +1024,8 @@ def forward( ---------- inp : torch.Tensor Input tensor. - m_splits : List[int] - List of integers representing the split of the input tensor. + m_splits : List[int] | torch.Tensor + List of integers or a device tensor representing the split of the input tensor. is_first_microbatch : {True, False, None}, default = None During training using either gradient accumulation or pipeline parallelism a minibatch of data is further split @@ -851,18 +1041,18 @@ def forward( produced) """ debug = self.is_debug_iter() - assert not isinstance( inp, QuantizedTensorStorage ), "GroupedLinear doesn't support input tensor in FP8." - assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs." - + m_splits_is_tensor = torch.is_tensor(m_splits) + num_splits = m_splits.numel() if m_splits_is_tensor else len(m_splits) + assert num_splits == self.num_gemms, "Number of splits should match number of GEMMs." 
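        # m_splits may be a Python list (e.g. [5, 7, 9]) or a CUDA integer tensor; the
        # tensor form selects the grouped-tensor GEMM path (currently only without
        # quantization).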
is_grad_enabled = torch.is_grad_enabled() inp = self.prepare_forward(inp, num_gemms=self.num_gemms) try: weight_tensors = self._get_weight_tensors() - bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] + bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_weight_params)] quantizers = self._get_quantizers() if not debug else self._get_debug_quantizers() @@ -892,6 +1082,7 @@ def forward( non_tensor_args = ( m_splits, + m_splits_is_tensor, self.apply_bias, is_first_microbatch, self.fp8, @@ -935,10 +1126,10 @@ def backward_dw(self): weight_params = [getattr(self, f"weight{i}") for i in range(self.num_gemms)] bias_params = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] if not self.fuse_wgrad_accumulation: - for i in range(self.num_gemms): + for i in range(self.num_weight_params): weight_params[i].grad = wgrad_list[i].to(weight_params[i].dtype) if self.use_bias: - for i in range(self.num_gemms): + for i in range(self.num_weight_params): if bias_params[i].grad is None: bias_params[i].grad = grad_biases_[i].to(bias_params[i].dtype) del grad_biases_ @@ -982,7 +1173,7 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]: """Get the weight tensors of the module.""" - weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)] + weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_weight_params)] if not self.fp8 and any(isinstance(w, QuantizedTensorStorage) for w in weight_tensors): warnings.warn( "You are using quantized weights without quantized compute. " @@ -1002,9 +1193,9 @@ def _get_weight_quantizers(self) -> List[Quantizer]: self.quantizers["scaling_fwd"][ self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"] ] - for i in range(self.num_gemms) + for i in range(self.num_weight_params) ] - for i in range(self.num_gemms): + for i in range(self.num_weight_params): weight_quantizers[i].internal = not self.primary_weights_in_fp8 return weight_quantizers
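For reference, the grouped GEMM layouts exercised above reduce to a plain loop of per-group matmuls, matching how the reference general_grouped_gemm branch of test_grouped_gemm_grouped_tensor is set up. The sketch below is illustrative only; reference_grouped_matmul is a hypothetical helper, with A and B playing the same roles as in that test (TN: A=weight, B=input; NN: A=weight, B=grad_output; NT: A=input, B=grad_output).

import torch

def reference_grouped_matmul(A, B, layout):
    """Per-group reference for the TN/NN/NT layouts used in the tests above."""
    outs = []
    for a, b in zip(A, B):
        if layout == "TN":    # out_i = input_i @ weight_i^T              -> (m_i, n)
            outs.append(b @ a.t())
        elif layout == "NN":  # dgrad_i = grad_output_i @ weight_i        -> (m_i, k)
            outs.append(b @ a)
        else:                 # "NT": wgrad_i = grad_output_i^T @ input_i -> (n, k)
            outs.append(b.t() @ a)
    return outs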