From 3e7859cc9ac3788b78dfe6b0c72a8d353ca652bb Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Fri, 6 Feb 2026 05:23:51 +0000 Subject: [PATCH 1/5] PyTorch-Python GroupedTensor Signed-off-by: Kirthi Shankar Sivamani --- tests/pytorch/test_grouped_tensor.py | 441 ++++++++ tests/pytorch/test_sanity.py | 131 ++- transformer_engine/common/recipe/__init__.py | 35 +- .../pytorch/module/grouped_linear.py | 77 +- .../pytorch/tensor/float8_tensor.py | 19 +- .../pytorch/tensor/mxfp8_tensor.py | 45 +- .../pytorch/tensor/nvfp4_tensor.py | 5 +- .../pytorch/tensor/storage/__init__.py | 1 + .../pytorch/tensor/storage/grouped_tensor.py | 942 ++++++++++++++++++ 9 files changed, 1666 insertions(+), 30 deletions(-) create mode 100644 tests/pytorch/test_grouped_tensor.py create mode 100644 transformer_engine/pytorch/tensor/storage/grouped_tensor.py diff --git a/tests/pytorch/test_grouped_tensor.py b/tests/pytorch/test_grouped_tensor.py new file mode 100644 index 0000000000..c9f1c024c8 --- /dev/null +++ b/tests/pytorch/test_grouped_tensor.py @@ -0,0 +1,441 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Tests for GroupedTensor class""" + +from typing import List, Tuple +import pytest +import torch +import transformer_engine.pytorch as te +from transformer_engine.pytorch.tensor.storage.grouped_tensor import GroupedTensor +from transformer_engine.pytorch import ( + Quantizer, + Float8Quantizer, + Float8CurrentScalingQuantizer, + Float8BlockQuantizer, + MXFP8Quantizer, + NVFP4Quantizer, +) +from transformer_engine.pytorch.constants import TE_DType_To_Torch +import transformer_engine_torch as tex + +# Check available recipes +fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available( + return_reason=True +) +mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) +nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True) + +_quantization_params = [ + pytest.param( + "fp8_delayed_scaling", + marks=pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8), + ), + pytest.param( + "fp8_current_scaling", + marks=pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8), + ), + pytest.param( + "fp8_blockwise", + marks=pytest.mark.skipif( + not fp8_block_scaling_available, reason=reason_for_no_fp8_block_scaling + ), + ), + pytest.param( + "mxfp8", + marks=pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8), + ), + pytest.param( + "nvfp4", + marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4), + ), +] + + +def make_quantizer(quantization: str, num_tensors: int, shape: List[Tuple[int, int]]) -> Quantizer: + """Create quantizers for given quantization scheme""" + + if quantization == "fp8_delayed_scaling": + quantizer = Float8Quantizer( + scale=torch.ones(1, dtype=torch.float32, device="cuda"), + amax=torch.zeros(1, dtype=torch.float32, device="cuda"), + fp8_dtype=tex.DType.kFloat8E4M3, + ) + elif quantization == "fp8_current_scaling": + quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + device="cuda", + ) + quantizer.set_usage(rowwise=True, columnwise=False) + elif quantization == "fp8_blockwise": + quantizer = Float8BlockQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise=True, + columnwise=False, + force_pow_2_scales=True, + amax_epsilon=0.0, + block_scaling_dim=1, + ) + elif 
quantization == "mxfp8": + quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) + elif quantization == "nvfp4": + quantizer = NVFP4Quantizer( + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=False, + stochastic_rounding=False, + with_random_sign_mask=False, + ) + else: + raise ValueError(f"Unknown quantization scheme: {quantization}") + + quantizer.internal = False + + return quantizer + + +def _get_rowwise_data_tensor(qtensor, quantization: str) -> torch.Tensor: + if quantization in ("fp8_delayed_scaling", "fp8_current_scaling"): + return qtensor._data + if quantization in ("fp8_blockwise", "mxfp8", "nvfp4"): + return qtensor._rowwise_data + raise ValueError(f"Unknown quantization scheme: {quantization}") + + +def _rowwise_offset_bytes(numel: int, quantization: str) -> int: + if quantization == "nvfp4": + return numel // 2 + return numel + + +class TestGroupedTensor: + @staticmethod + def setup_class(cls) -> None: + # Configure RNG + seed = 1234 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + def test_basic_construction_all_same_shape(self) -> None: + """Test GroupedTensor construction with all tensors having same shape""" + num_tensors = 4 + shape = [(256, 512) for _ in range(num_tensors)] + + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shape=shape, + quantizer=None, + device="cuda", + dtype=torch.float32, + ) + + assert grouped_tensor.num_tensors == num_tensors + assert grouped_tensor.all_same_shape() + assert grouped_tensor.all_same_first_dim() + assert grouped_tensor.all_same_last_dim() + assert grouped_tensor.logical_shape == (num_tensors * 256, 512) + assert grouped_tensor.get_common_first_dim() == 256 + assert grouped_tensor.get_common_last_dim() == 512 + assert grouped_tensor.has_data() + + def test_basic_construction_varying_first_dim(self) -> None: + """Test GroupedTensor construction with varying first dimension""" + num_tensors = 3 + shape = [(128, 512), (256, 512), (384, 512)] + + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shape=shape, + quantizer=None, + device="cuda", + dtype=torch.float32, + ) + + assert grouped_tensor.num_tensors == num_tensors + assert not grouped_tensor.all_same_shape() + assert not grouped_tensor.all_same_first_dim() + assert grouped_tensor.all_same_last_dim() + assert grouped_tensor.get_common_last_dim() == shape[0][1] + assert grouped_tensor.logical_shape == ( + sum(v for v, _ in shape), + shape[0][1], + ) # sum of first dims + + def test_split_into_quantized_tensors_no_quantization(self) -> None: + """Test split_into_quantized_tensors for unquantized tensors""" + num_tensors = 3 + shape = [(256, 512) for _ in range(num_tensors)] + + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shape=shape, + quantizer=None, + device="cuda", + dtype=torch.float32, + ) + + # Get the original data pointer + original_data_ptr = grouped_tensor.data.data_ptr() + + # Split into tensors + tensors = grouped_tensor.split_into_quantized_tensors() + + assert len(tensors) == num_tensors + + # Verify each tensor has correct shape and shares storage + for i, tensor in enumerate(tensors): + assert tensor.shape == shape[i] + assert isinstance(tensor, torch.Tensor) + assert not hasattr(tensor, "_data") # Not a quantized tensor + + # Verify data pointer is within the original grouped tensor storage + # The tensor should be a view of the original data + assert tensor.data_ptr() >= original_data_ptr + + 
# Calculate expected offset + expected_offset = i * (shape[i][0] * shape[i][1]) * tensor.element_size() + assert tensor.data_ptr() == original_data_ptr + expected_offset + + @pytest.mark.parametrize("quantization", _quantization_params) + def test_split_into_quantized_tensors_quantized(self, quantization: str) -> None: + """Test split_into_quantized_tensors for quantized tensors""" + num_tensors = 3 + shape = [(512, 512) for _ in range(num_tensors)] + quantizers = make_quantizer(quantization, num_tensors, shape) + + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shape=shape, + quantizer=quantizers, + device="cuda", + ) + + # Get the original data pointer + original_data_ptr = grouped_tensor.data.data_ptr() + + # Split into tensors + tensors = grouped_tensor.split_into_quantized_tensors() + + assert len(tensors) == num_tensors + + # Verify each tensor shares storage with the grouped tensor + for i, tensor in enumerate(tensors): + rowwise_data = _get_rowwise_data_tensor(tensor, quantization) + assert rowwise_data is not None + assert rowwise_data.data_ptr() >= original_data_ptr + numel = shape[i][0] * shape[i][1] + expected_offset = _rowwise_offset_bytes(i * numel, quantization) + assert rowwise_data.data_ptr() == original_data_ptr + expected_offset + + def test_split_varying_shapes(self) -> None: + """Test split_into_quantized_tensors with varying shapes""" + num_tensors = 3 + shape = [(128, 512), (256, 512), (384, 512)] + + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shape=shape, + quantizer=None, + device="cuda", + dtype=torch.float32, + ) + + original_data_ptr = grouped_tensor.data.data_ptr() + tensors = grouped_tensor.split_into_quantized_tensors() + + assert len(tensors) == num_tensors + + # Verify shapes and storage + cumulative_offset = 0 + for i, tensor in enumerate(tensors): + assert tensor.shape == shape[i] + expected_offset = cumulative_offset * tensor.element_size() + assert tensor.data_ptr() == original_data_ptr + expected_offset + cumulative_offset += shape[i][0] * shape[i][1] + + @pytest.mark.parametrize("quantization", _quantization_params) + def test_quantize_inplace(self, quantization: str) -> None: + """Test that quantize is done in-place for all recipes""" + num_tensors = 3 + shape = [(512, 512) for _ in range(num_tensors)] + quantizers = make_quantizer(quantization, num_tensors, shape) + + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shape=shape, + quantizer=quantizers, + device="cuda", + ) + + # Get original data pointers before quantization + original_data_ptr = grouped_tensor.data.data_ptr() + original_scale_inv_ptr = grouped_tensor.scale_inv.data_ptr() + original_scale_ptr = ( + grouped_tensor.scale.data_ptr() if grouped_tensor.scale is not None else None + ) + + # Create input tensors + input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape] + + # Quantize in place + quantized_tensors = grouped_tensor.quantize(input_tensors) + + # Verify data pointers haven't changed (in-place operation) + assert grouped_tensor.data.data_ptr() == original_data_ptr + assert grouped_tensor.scale_inv.data_ptr() == original_scale_inv_ptr + if original_scale_ptr is not None: + assert grouped_tensor.scale.data_ptr() == original_scale_ptr + + # Verify returned tensors point to the same storage + for i, qtensor in enumerate(quantized_tensors): + rowwise_data = _get_rowwise_data_tensor(qtensor, quantization) + numel = 
shape[i][0] * shape[i][1] + expected_offset = _rowwise_offset_bytes(i * numel, quantization) + assert rowwise_data.data_ptr() == original_data_ptr + expected_offset + + @pytest.mark.parametrize("quantization", _quantization_params) + def test_quantize_varying_shapes(self, quantization: str) -> None: + """Test quantize with varying shapes""" + num_tensors = 3 + shape = [(256, 512), (512, 512), (768, 512)] + quantizers = make_quantizer(quantization, num_tensors, shape) + + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shape=shape, + quantizer=quantizers, + device="cuda", + ) + + # Get original data pointers + original_data_ptr = grouped_tensor.data.data_ptr() + + # Create input tensors with varying shapes + input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape] + + # Quantize in place + quantized_tensors = grouped_tensor.quantize(input_tensors) + + # Verify data pointer hasn't changed + assert grouped_tensor.data.data_ptr() == original_data_ptr + + # Verify each tensor points to correct location + cumulative_numel = 0 + for qtensor, tensor_shape in zip(quantized_tensors, shape): + rowwise_data = _get_rowwise_data_tensor(qtensor, quantization) + expected_offset = _rowwise_offset_bytes(cumulative_numel, quantization) + assert rowwise_data.data_ptr() == original_data_ptr + expected_offset + cumulative_numel += tensor_shape[0] * tensor_shape[1] + + @pytest.mark.parametrize("quantization", _quantization_params) + def test_static_quantize_method(self, quantization: str) -> None: + """Test the static quantize method""" + num_tensors = 3 + shape = [(512, 512) for _ in range(num_tensors)] + quantizers = make_quantizer(quantization, num_tensors, shape) + + # Create input tensors + input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape] + + # Use static quantize method + grouped_tensor = GroupedTensor.create_and_quantize( + tensors=input_tensors, + quantizer=quantizers, + device="cuda", + ) + + # Verify the grouped tensor was created correctly + assert grouped_tensor.num_tensors == num_tensors + assert grouped_tensor.has_data() + + # Verify quantized_tensors were created and point to same storage + assert grouped_tensor.quantized_tensors is not None + assert len(grouped_tensor.quantized_tensors) == num_tensors + + original_data_ptr = grouped_tensor.data.data_ptr() + for i, qtensor in enumerate(grouped_tensor.quantized_tensors): + rowwise_data = _get_rowwise_data_tensor(qtensor, quantization) + numel = shape[i][0] * shape[i][1] + expected_offset = _rowwise_offset_bytes(i * numel, quantization) + assert rowwise_data.data_ptr() == original_data_ptr + expected_offset + + @pytest.mark.parametrize( + "shape", + [[(256, 512), (512, 512), (768, 512)], [(512, 512), (512, 512), (512, 512)]], + ) + @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) + def test_quantize_grouped_mxfp8(self, shape: List[Tuple[int, int]]) -> None: + """Test grouped quantization for MXFP8 against per-tensor quantization.""" + # Test wont pass until the grouped quantization PR from Oleg is merged. 
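+        # NOTE: the parametrized `shape` is overridden below with a fixed (512, 1024) per-tensor layout.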
+ num_tensors = 2 + shape = [(512, 1024) for _ in range(num_tensors)] + + # Create BF16 input tensors and pack into a grouped tensor + input_tensors = [torch.randn(s, dtype=torch.bfloat16, device="cuda") for s in shape] + quantized_tensors = [ + MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(tensor) for tensor in input_tensors + ] + grouped_input = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shape=shape, + quantizer=None, + device="cuda", + dtype=torch.bfloat16, + ) + + offset = 0 + for tensor in input_tensors: + numel = tensor.numel() + grouped_input.data[offset : offset + numel].copy_(tensor.reshape(-1)) + offset += numel + + # Create MXFP8 output grouped tensor (rowwise only for easier validation) + quantizers = [MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) for _ in range(num_tensors)] + + grouped_output = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shape=shape, + quantizer=quantizers, + device="cuda", + ) + + # Quantize using grouped API (handle both 2-arg and 3-arg bindings) + _ = tex.quantize_grouped(grouped_input, grouped_output) + # Build expected output by quantizing each tensor independently + expected_data = [] + expected_scale_inv = [] + for tensor, quantizer in zip(input_tensors, quantizers): + qtensor = quantizer(tensor) + expected_data.append(qtensor._rowwise_data.reshape(-1)) + expected_scale_inv.append(qtensor._rowwise_scale_inv.reshape(-1)) + + expected_data = torch.cat(expected_data) + expected_scale_inv = torch.cat(expected_scale_inv) + + assert torch.equal(grouped_output.data, expected_data) + assert torch.equal(grouped_output.scale_inv, expected_scale_inv) + + def test_clear(self) -> None: + """Test clear method""" + num_tensors = 3 + shape = [(256, 512) for _ in range(num_tensors)] + + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shape=shape, + quantizer=None, + device="cuda", + dtype=torch.float32, + ) + + assert grouped_tensor.has_data() + assert grouped_tensor.num_tensors == num_tensors + + grouped_tensor.clear() + + assert not grouped_tensor.has_data() + assert grouped_tensor.num_tensors == 0 + assert grouped_tensor.data is None + assert grouped_tensor.logical_shape == (0, 0) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index e9d24c1a8e..f5bb47ab5c 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -2,7 +2,7 @@ # # See LICENSE for license information. -from typing import Optional +from typing import Optional, List import torch import pytest @@ -137,6 +137,117 @@ def reset_global_fp8_state(): FP8GlobalStateManager.reset() +def check_grouped_tensor_pointers_helper(tensors, num_elems_in_byte=1, tensor_name="tensor"): + """ + Verify that tensors are stored in contiguous memory. 
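+    "Contiguous" here means tensor[i].data_ptr() equals tensor[i-1].data_ptr() plus the size of
+    tensor[i-1] in bytes, with num_elems_in_byte accounting for packed sub-byte (FP4) elements.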
+ + Args: + tensors: List or iterable of tensors to check + num_elems_in_byte: Number of elements packed per byte (1 for normal, 2 for NVFP4) + tensor_name: Name to use in error messages + """ + tensor_list = list(tensors) + if len(tensor_list) < 2: + return # Nothing to check + + for i in range(1, len(tensor_list)): + prev_tensor = tensor_list[i - 1] + curr_tensor = tensor_list[i] + + # Calculate expected offset based on previous tensor size + prev_numel = prev_tensor.numel() + expected_offset = (prev_numel // num_elems_in_byte) * prev_tensor.element_size() + + # Verify current tensor's data pointer is correctly offset + expected_ptr = prev_tensor.data_ptr() + expected_offset + actual_ptr = curr_tensor.data_ptr() + + assert ( + actual_ptr == expected_ptr + ), f"{tensor_name} {i} data pointer mismatch: expected {expected_ptr}, got {actual_ptr}" + + +def check_grouped_tensor_pointers( + weights: List[torch.Tensor], fp8_recipe: Optional[recipe.Recipe] = None +): + """ + Verify that the pointers of the weights are in contiguous memory for GroupedTensor. + TODO(ksivaman): This check can be made way more efficient but for now leaving the brute force approach. + """ + + num_elems_in_a_data_byte = 1 if fp8_recipe is None else 2 if fp8_recipe.nvfp4() else 1 + + # Check data. + if hasattr(weights[0], "_data") and weights[0]._data is not None: + data_tensors = [w._data for w in weights] + check_grouped_tensor_pointers_helper(data_tensors, num_elems_in_byte=1, tensor_name="data") + + # Check transpose. + if hasattr(weights[0], "_transpose") and weights[0]._transpose is not None: + transpose_tensors = [w._transpose for w in weights] + check_grouped_tensor_pointers_helper( + transpose_tensors, num_elems_in_byte=1, tensor_name="transpose" + ) + + # Check scale_inv. + if hasattr(weights[0], "_scale_inv") and weights[0]._scale_inv is not None: + scale_inv_tensors = [w._scale_inv for w in weights] + check_grouped_tensor_pointers_helper( + scale_inv_tensors, num_elems_in_byte=1, tensor_name="scale_inv" + ) + + # Check rowwise scale_inv. + if hasattr(weights[0], "_rowwise_scale_inv") and weights[0]._rowwise_scale_inv is not None: + scale_inv_tensors = [w._rowwise_scale_inv for w in weights] + check_grouped_tensor_pointers_helper( + scale_inv_tensors, num_elems_in_byte=1, tensor_name="rowwise_scale_inv" + ) + + # Check columnwise scale_inv. + if ( + hasattr(weights[0], "_columnwise_scale_inv") + and weights[0]._columnwise_scale_inv is not None + ): + columnwise_scale_inv_tensors = [w._columnwise_scale_inv for w in weights] + check_grouped_tensor_pointers_helper( + columnwise_scale_inv_tensors, + num_elems_in_byte=1, + tensor_name="columnwise scale_inv", + ) + + # Check rowwise amax. + if hasattr(weights[0], "_rowwise_amax") and weights[0]._rowwise_amax is not None: + rowwise_amax_tensors = [w._rowwise_amax for w in weights] + check_grouped_tensor_pointers_helper( + rowwise_amax_tensors, num_elems_in_byte=1, tensor_name="rowwise amax" + ) + + # Check columnwise amax. + if hasattr(weights[0], "_columnwise_amax") and weights[0]._columnwise_amax is not None: + columnwise_amax_tensors = [w._columnwise_amax for w in weights] + check_grouped_tensor_pointers_helper( + columnwise_amax_tensors, num_elems_in_byte=1, tensor_name="columnwise amax" + ) + + # Check rowwise data. 
+ if hasattr(weights[0], "_rowwise_data") and weights[0]._rowwise_data is not None: + rowwise_data_tensors = [w._rowwise_data for w in weights] + check_grouped_tensor_pointers_helper( + rowwise_data_tensors, + num_elems_in_byte=num_elems_in_a_data_byte, + tensor_name="rowwise data", + ) + + # Check columnwise data. + if hasattr(weights[0], "_columnwise_data") and weights[0]._columnwise_data is not None: + columnwise_data_tensors = [w._columnwise_data for w in weights] + check_grouped_tensor_pointers_helper( + columnwise_data_tensors, + num_elems_in_byte=num_elems_in_a_data_byte, + tensor_name="columnwise data", + ) + + def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad): te_inp_hidden_states = torch.randn( (config.max_seqlen_q, config.batch_size, config.hidden_size), @@ -495,9 +606,17 @@ def test_sanity_grouped_linear( use_fp8 = fp8_recipe is not None with quantized_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe): te_grouped_linear = GroupedLinear( - num_gemms, config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype + num_gemms, + config.hidden_size, + ffn_hidden_size, + bias=use_bias, + params_dtype=dtype, ).cuda() + # Verify that weights are stored in contiguous GroupedTensor storage. + weights = [getattr(te_grouped_linear, f"weight{i}") for i in range(num_gemms)] + check_grouped_tensor_pointers(weights, fp8_recipe) + inp_hidden_states = torch.randn( num_tokens, config.hidden_size, dtype=dtype, requires_grad=True ).cuda() @@ -956,7 +1075,13 @@ def test_replace_raw_data_for_float8tensor(): random_bf16_data = torch.randn(fp8_tensor.shape, dtype=torch.bfloat16, device="cuda") fp8_quantizer.update_quantized(random_bf16_data, fp8_tensor) - attrs_to_check = ["_quantizer", "_fp8_dtype", "_scale_inv", "_transpose", "_transpose_invalid"] + attrs_to_check = [ + "_quantizer", + "_fp8_dtype", + "_scale_inv", + "_transpose", + "_transpose_invalid", + ] attrs = {} for attr in attrs_to_check: attrs[attr] = getattr(fp8_tensor, attr) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 64ee2a5a16..18577b0eb4 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -88,33 +88,40 @@ class Recipe: Base recipe class. 
""" - def nvfp4(self): + @classmethod + def nvfp4(cls): """Whether the given recipe is NVFP4 1D block scaling.""" - return isinstance(self, NVFP4BlockScaling) + return issubclass(cls, NVFP4BlockScaling) - def mxfp8(self): + @classmethod + def mxfp8(cls): """Whether the given recipe is MXFP8 block scaling.""" - return isinstance(self, MXFP8BlockScaling) + return issubclass(cls, MXFP8BlockScaling) - def delayed(self): + @classmethod + def delayed(cls): """Whether the given recipe is delayed scaling.""" - return isinstance(self, DelayedScaling) + return issubclass(cls, DelayedScaling) - def float8_current_scaling(self): + @classmethod + def float8_current_scaling(cls): """Whether the given recipe is (per-tensor) current scaling.""" - return isinstance(self, Float8CurrentScaling) + return issubclass(cls, Float8CurrentScaling) - def float8_per_tensor_scaling(self): + @classmethod + def float8_per_tensor_scaling(cls): """Whether the given recipe is per-tensor scaling.""" - return isinstance(self, (DelayedScaling, Float8CurrentScaling)) + return issubclass(cls, (DelayedScaling, Float8CurrentScaling)) - def float8_block_scaling(self): + @classmethod + def float8_block_scaling(cls): """Whether the given recipe is float8 blockwise scaling.""" - return isinstance(self, Float8BlockScaling) + return issubclass(cls, Float8BlockScaling) - def custom(self): + @classmethod + def custom(cls): """Whether the given recipe is custom.""" - return isinstance(self, CustomRecipe) + return issubclass(cls, CustomRecipe) @dataclass() diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index c9ceb714e3..1709bf1b37 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -13,6 +13,7 @@ import transformer_engine_torch as tex from transformer_engine.common.recipe import Recipe +from transformer_engine.pytorch.tensor.storage.grouped_tensor import GroupedTensor from .base import ( get_dummy_wgrad, TransformerEngineBaseModule, @@ -147,7 +148,10 @@ def forward( # tensors (like scales), but bulk allocation shares storage across all tensors, # so if scales can't be offloaded, nothing in the group can be offloaded. inputmats = tex.split_quantize( - inp_view, m_splits, input_quantizers, disable_bulk_allocation=cpu_offloading + inp_view, + m_splits, + input_quantizers, + disable_bulk_allocation=cpu_offloading, ) elif debug: inputmats = DebugQuantizer.multi_tensor_quantize( @@ -365,7 +369,10 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], for i in range(ctx.num_gemms): grad_biases[i] = grad_output_mats[i].sum(dim=0) grad_output = DebugQuantizer.multi_tensor_quantize( - grad_output_view, ctx.grad_output_quantizers, ctx.m_splits, ctx.activation_dtype + grad_output_view, + ctx.grad_output_quantizers, + ctx.m_splits, + ctx.activation_dtype, ) else: # Only split grad output. 
Grad bias is fused with @@ -436,7 +443,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if ctx.input_quantizers[0] is not None: for input_quantizer in ctx.input_quantizers: if isinstance( - input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) + input_quantizer, + (Float8Quantizer, Float8CurrentScalingQuantizer), ): input_quantizer.set_usage(rowwise=True, columnwise=True) else: @@ -446,7 +454,10 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmats = tex.split_quantize(inp_view, ctx.m_splits, ctx.input_quantizers) elif ctx.debug: inputmats = DebugQuantizer.multi_tensor_quantize( - inp_view, ctx.input_quantizers, ctx.m_splits, ctx.activation_dtype + inp_view, + ctx.input_quantizers, + ctx.m_splits, + ctx.activation_dtype, ) else: inputmats = torch.split( @@ -616,7 +627,7 @@ def __init__( ) -> None: super().__init__(name) - params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype + self.params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype self.num_gemms = num_gemms self.in_features = in_features self.out_features = out_features @@ -631,12 +642,19 @@ def __init__( assert ( not ub_overlap_rs and not ub_overlap_ag ), "GroupedLinear doesn't support Userbuffer overlap." + self.init_method = init_method self.get_rng_state_tracker = get_rng_state_tracker self.rng_tracker_name = rng_tracker_name self.wgrad_store = WeightGradStore(delay_wgrad_compute) - self._offsets = {"input": 0, "weight": 1, "output": 2, "grad_output": 0, "grad_input": 1} + self._offsets = { + "input": 0, + "weight": 1, + "output": 2, + "grad_output": 0, + "grad_input": 1, + } self._num_fp8_tensors_per_gemm = { "fwd": 3, "bwd": 2, @@ -678,7 +696,7 @@ def __init__( self.out_features, self.in_features, device=device, - dtype=params_dtype, + dtype=self.params_dtype, ), ), init_fn=init_method, @@ -694,13 +712,13 @@ def __init__( torch.empty( self.out_features, device=device, - dtype=params_dtype, + dtype=self.params_dtype, ), ), init_fn=init_method_constant(0.0), ) else: - bias = torch.Tensor().to(dtype=params_dtype, device=device) + bias = torch.Tensor().to(dtype=self.params_dtype, device=device) setattr(self, f"bias{i}", bias) if self.primary_weights_in_fp8: @@ -724,8 +742,49 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: if recipe.float8_current_scaling(): self._customize_quantizers_float8_current_scaling(fwd, recipe) + def make_grouped_weights(self, defer_init=False) -> None: + """ + Convert parameters into a GroupedTensor and re-register them as parameters. + """ + + if defer_init: + return + + weights = [getattr(self, f"weight{i}") for i in range(self.num_gemms)] + weight_quantizers = self._get_weight_quantizers() + + # Create the weight storage. + grouped_weights = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=self.num_gemms, + shape=[(self.out_features, self.in_features)] * self.num_gemms, + quantizer=weight_quantizers[0], + dtype=self.params_dtype, + ) + + # Copy existing params into storage. + # TODO(ksivamani): Verify correctness of copy for all recipes. + with torch.no_grad(): + for i in range(self.num_gemms): + grouped_weights.quantized_tensors[i].copy_(weights[i]) + + # Re-register the grouped weights as parameters. 
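+        # Each parameter registered below is a zero-copy view into the shared GroupedTensor
+        # storage, so the per-GEMM weights stay contiguous in a single allocation.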
+ for i in range(self.num_gemms): + self.register_parameter( + f"weight{i}", + torch.nn.Parameter(grouped_weights.quantized_tensors[i]), + init_fn=self.init_method, + get_rng_state_tracker=self.get_rng_state_tracker, + fp8_meta_index=self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"], + ) + + self.set_tensor_parallel_attributes(defer_init=defer_init) + def reset_parameters(self, defer_init=False): super().reset_parameters(defer_init=defer_init) + self.make_grouped_weights(defer_init=defer_init) + + def set_tensor_parallel_attributes(self, defer_init=False) -> None: + """Set attributes needed for TP""" if not defer_init: # Set parallelism attributes for linear weights diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 3aeace0a77..55bca49af3 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -11,7 +11,11 @@ import transformer_engine_torch as tex from transformer_engine_torch import DType as TE_DType -from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling, Recipe +from transformer_engine.common.recipe import ( + DelayedScaling, + Float8CurrentScaling, + Recipe, +) from ..utils import canonicalize_process_group, devices_match from .storage.float8_tensor_storage import Float8TensorStorage, _FromFloat8Func from ..quantized_tensor import QuantizedTensor, Quantizer @@ -154,6 +158,10 @@ def calibrate(self, tensor: torch.Tensor) -> None: amin, amax = tensor.aminmax() self.amax.copy_(torch.max(-amin, amax)) + def get_columnwise_shape(self, rowwise_data_shape: Iterable[int]) -> Tuple[int, ...]: + """Calculate the shape of the columnwise data for Float8 1D blockwise quantization.""" + return [rowwise_data_shape[-1]] + list(rowwise_data_shape[:-1]) + def create_tensor_from_data( self, data: torch.Tensor, @@ -408,6 +416,10 @@ def create_tensor_from_data( quantizer=self, ) + def get_columnwise_shape(self, rowwise_data_shape: Iterable[int]) -> Tuple[int, ...]: + """Calculate the shape of the columnwise data for Float8 1D blockwise quantization.""" + return [rowwise_data_shape[-1]] + list(rowwise_data_shape[:-1]) + def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor: """Function using primitives with ONNX defined translations.""" if tensor.dtype != torch.float32: @@ -769,7 +781,10 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): kwargs, ) return Float8Tensor.make_like( - tensor, data=func_out, data_transpose=func_transposed_out, shape=func_out.shape + tensor, + data=func_out, + data_transpose=func_transposed_out, + shape=func_out.shape, ) if func == torch.ops.aten.detach.default: diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 8dd2255d89..41d6c87f2b 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -164,6 +164,49 @@ def calibrate(self, tensor: torch.Tensor) -> None: # TODO(ksivamani): No calibration needed for mxfp8? pass + def get_scale_shape( + self, + shape: Iterable[int], + columnwise: bool, + ) -> Tuple[int, int]: + """Calculate the shape of the scaling tensor for MXFP8 1D blockwise quantization. + + This method determines the shape of the scaling tensor needed for blockwise quantization, + taking into account the input tensor shape and whether columnwise scaling is used. 
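+        For example, a (512, 512) input with rowwise scaling uses a scale shape of
+        (512, 512 // 32) = (512, 16), padded up to multiples of (128, 4).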
+ + Parameters + ---------- + shape : Iterable[int] + Shape of the input tensor to be quantized + columnwise : bool + Whether to use columnwise scaling (True) or rowwise scaling (False) + + Returns + ------- + Tuple[int, int] + Shape of the scaling tensor as (outer_dim, inner_dim) + For MXFP8 1D blockwise quantization, blocksize is 32 + Swizzle kernel will be performed before GEMM to suit the need of CuBLAS. + CuBLAS doc: https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout + """ + if columnwise: + # Columnwise: scale_inv shape is [prod(shape[:-1]) // BLOCK_SIZE, shape[-1]] + # with padding to multiples of [4, 128] + return ( + round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4), + round_up_to_nearest_multiple(shape[-1], 128), + ) + # Rowwise: scale_inv shape is [prod(shape[:-1]), shape[-1] // BLOCK_SIZE] + # with padding to multiples of [128, 4] + return ( + round_up_to_nearest_multiple(math.prod(shape[:-1]), 128), + round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4), + ) + + def get_columnwise_shape(self, rowwise_data_shape: Tuple[int, ...]) -> Tuple[int, ...]: + """Calculate the shape of the columnwise data for MXFP8 1D blockwise quantization.""" + return rowwise_data_shape + def create_tensor_from_data( self, data: torch.Tensor, @@ -704,7 +747,7 @@ def fsdp_post_all_gather( columnwise_scale_inv=columnwise_scale_inv, fp8_dtype=fp8_dtype, dtype=param_dtype, - shape=rowwise_data.shape if rowwise_data is not None else columnwise_data.shape, + shape=(rowwise_data.shape if rowwise_data is not None else columnwise_data.shape), quantizer=self._quantizer, with_gemm_swizzled_scales=False, ) diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 101cf78a8f..66f986a900 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -341,7 +341,10 @@ def make_empty( ) columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True) columnwise_scale_inv = torch.empty( - columnwise_scale_shape, dtype=torch.uint8, device=device, pin_memory=pin_memory + columnwise_scale_shape, + dtype=torch.uint8, + device=device, + pin_memory=pin_memory, ) amax_columnwise = torch.zeros( 1, dtype=torch.float32, device=device, pin_memory=pin_memory diff --git a/transformer_engine/pytorch/tensor/storage/__init__.py b/transformer_engine/pytorch/tensor/storage/__init__.py index d7a2719200..7c8a014c1d 100644 --- a/transformer_engine/pytorch/tensor/storage/__init__.py +++ b/transformer_engine/pytorch/tensor/storage/__init__.py @@ -7,3 +7,4 @@ from .mxfp8_tensor_storage import MXFP8TensorStorage # noqa: F401 from .float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage # noqa: F401 from .nvfp4_tensor_storage import NVFP4TensorStorage # noqa: F401 +from .grouped_tensor import GroupedTensor # noqa: F401 diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor.py new file mode 100644 index 0000000000..c6f345e896 --- /dev/null +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor.py @@ -0,0 +1,942 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +"""Grouped tensor class for handling collections of tensors with different shapes""" +from __future__ import annotations +from typing import Optional, Tuple, List, Union +import math + +import torch + +from ...quantized_tensor import QuantizedTensorStorage, Quantizer + +from ..mxfp8_tensor import MXFP8Tensor +from ..nvfp4_tensor import NVFP4Tensor +from ..float8_tensor import Float8Tensor +from ..float8_blockwise_tensor import Float8BlockwiseQTensor +from .float8_tensor_storage import Float8TensorStorage +from .mxfp8_tensor_storage import MXFP8TensorStorage +from .float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage +from .nvfp4_tensor_storage import NVFP4TensorStorage + + +class GroupedTensor: + """ + EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. + + Grouped tensor is a collection of tensors with different shapes but the same dtype and scaling mode. + + Shape Representation: + - logical_shape: 2D shape representing the conceptual layout, i.e. the shape when member tensors + are flattened to 2D and stacked together (REQUIRED) + + When all_same_shape(): [num_tensors * M, N] where each tensor is (M, N) + + When varying_first_dim(): [~sum_of_first_dims, N] where N is common + + When varying_last_dim(): [M, ~sum_of_last_dims] where M is common + + When varying_both_dims(): [1, total_elements] (fully flattened) + + - first_dims and last_dims are OPTIONAL (None if dimension is uniform) + + None first_dims: all tensors have the same first dimension + + None last_dims: all tensors have the same last dimension + + Both None: all tensors have identical shapes + + Both set: each tensor has unique shape (first_dims[i], last_dims[i]) + + Data Layout: + - ALL data fields are stored as 1D flattened arrays (data, columnwise_data, scale_inv, etc.) + - logical_shape provides the conceptual 2D interpretation + - All data is stored on device in contiguous layout + + Note: This structure is used only for combined storage of multiple tensors with the same dtype and scaling mode. + """ + + def __init__( + self, + num_tensors: int, + shape: List[Tuple[int, int]], + quantizer: Optional[Quantizer] = None, + dtype: Optional[torch.dtype] = None, + data: Optional[torch.Tensor] = None, + columnwise_data: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, + columnwise_scale_inv: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + columnwise_amax: Optional[torch.Tensor] = None, + scale: Optional[torch.Tensor] = None, + first_dims: Optional[torch.Tensor] = None, + last_dims: Optional[torch.Tensor] = None, + tensor_offsets: Optional[torch.Tensor] = None, + offsets: Optional[List[int]] = None, + scale_inv_offsets: Optional[List[int]] = None, + columnwise_scale_inv_offsets: Optional[List[int]] = None, + logical_shape: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Initialize a GroupedTensor. 
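+        Instances are typically created through the factory helpers such as
+        make_grouped_tensor_with_shapes, which allocate the buffers and compute the offsets,
+        rather than by calling __init__ directly.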
+ + Args: + num_tensors: Number of tensors in the group + shape: 2D shape of each tensor (len num_tensors) + quantizer: Quantizer for the grouped tensor + data: Row-wise data buffer (1D flattened) + columnwise_data: Column-wise data buffer (1D flattened) + scale_inv: Row-wise scale inverse buffer + columnwise_scale_inv: Column-wise scale inverse buffer + amax: Row-wise amax buffer + columnwise_amax: Column-wise amax buffer + scale: Scale buffer (for FP8-DS only) + first_dims: Device tensor of int64 array of length num_tensors (or None if uniform) + last_dims: Device tensor of int64 array of length num_tensors (or None if uniform) + tensor_offsets: Device tensor of int64 array of length num_tensors (or None if uniform) + offsets: Vector of integer offsets for each tensor. + logical_shape: 2D tuple representing conceptual shape + """ + self.num_tensors = num_tensors + self.quantizer = quantizer + self.shape = shape + self.dtype = ( + dtype if dtype is not None else torch.float32 + ) # Default to float32 if not provided + + # Data buffers + self.data = data + self.columnwise_data = columnwise_data + self.scale_inv = scale_inv + self.columnwise_scale_inv = columnwise_scale_inv + self.amax = amax + self.columnwise_amax = columnwise_amax + self.scale = scale + + # For convenient indexing for python GroupedTensor API. + self.scale_inv_offsets = scale_inv_offsets + self.columnwise_scale_inv_offsets = columnwise_scale_inv_offsets + + # Shape information (OPTIONAL - None if dimension is uniform across all tensors) + # first_dims[i] = first dimension of tensor i (None if all tensors have same first dim) + # last_dims[i] = last dimension of tensor i (None if all tensors have same last dim) + self.first_dims = ( + first_dims # Device pointer to int64_t array of length num_tensors (or None) + ) + self.last_dims = ( + last_dims # Device pointer to int64_t array of length num_tensors (or None) + ) + + # Offsets for indexing into contiguous 1D layout (OPTIONAL - not needed if all_same_shape()) + # tensor_offsets[i] = element offset to start of tensor i (cumulative sum of numel for tensors 0..i-1) + # Usage: tensor_i_ptr = data.data_ptr() + tensor_offsets[i] * element_size + # If None and all_same_shape(): offset[i] = i * M * N (where M, N are common dimensions) + self.tensor_offsets = ( + tensor_offsets # Device pointer to int64_t array of length num_tensors (or None) + ) + self.offsets = offsets # Vector of integer offsets for each tensor. + + # Logical shape: conceptual 2D shape of the grouped data (REQUIRED) + # Represents how the 1D flattened data should be interpreted as 2D + # Always 2D with positive dimensions + self.logical_shape = logical_shape if logical_shape is not None else (0, 0) + + # Hold a reference to the quantized tensors that occupy same storage as the GroupedTensor. + # Used as a convenience. + self.quantized_tensors = None + + def has_data(self) -> bool: + """ + Check if the tensor has row-wise data. + + Returns: + True if data buffer is initialized, False otherwise + """ + return self.data is not None + + def has_columnwise_data(self) -> bool: + """ + Check if the tensor has column-wise data. + + Returns: + True if columnwise_data buffer is initialized, False otherwise + """ + return self.columnwise_data is not None + + def all_same_first_dim(self) -> bool: + """ + Check if all tensors in the group have the same first dimension. 
+ + Returns: + True if first dimension is uniform across all tensors + """ + return self.first_dims is None + + def all_same_last_dim(self) -> bool: + """ + Check if all tensors in the group have the same last dimension. + + Returns: + True if last dimension is uniform across all tensors + """ + return self.last_dims is None + + def all_same_shape(self) -> bool: + """ + Check if all tensors in the group have identical shapes. + + Returns: + True if all tensors have the same shape + """ + return self.first_dims is None and self.last_dims is None + + def varying_both_dims(self) -> bool: + """ + Check if both dimensions vary across tensors. + + Returns: + True if both first and last dimensions vary + """ + return self.first_dims is not None and self.last_dims is not None + + def get_common_first_dim(self) -> int: + """ + Get the common first dimension when all tensors share it. + + Returns: + The common first dimension + + Raises: + RuntimeError: If first dimension varies across tensors or logical_shape is not 2D + """ + if not self.all_same_first_dim(): + raise RuntimeError("First dim varies across tensors") + if len(self.logical_shape) != 2: + raise RuntimeError("Logical shape must be 2D") + + if self.all_same_shape(): + # When both dims are uniform: logical_shape = [num_tensors * M, N] + return self.logical_shape[0] // self.num_tensors + # When varying last dims but not first dim: logical_shape = [M, sum_of_last_dims] + return self.logical_shape[0] + + def get_common_last_dim(self) -> int: + """ + Get the common last dimension when all tensors share it. + + Returns: + The common last dimension + + Raises: + RuntimeError: If last dimension varies across tensors or logical_shape is not 2D + """ + if not self.all_same_last_dim(): + raise RuntimeError("Last dim varies across tensors") + if len(self.logical_shape) != 2: + raise RuntimeError("Logical shape must be 2D") + + # For both uniform and varying first dim cases: logical_shape[1] is the common last dim + return self.logical_shape[1] + + def get_dtype(self) -> torch.dtype: + """ + Get the high precision data type of the tensor. + + Returns: + The high precision dtype of the data buffer + """ + + return self.dtype + + def clear(self) -> None: + """ + Reset tensor data and clear all buffers. 
+ """ + self.data = None + self.columnwise_data = None + self.scale_inv = None + self.columnwise_scale_inv = None + self.amax = None + self.columnwise_amax = None + self.scale = None + self.first_dims = None + self.last_dims = None + self.tensor_offsets = None + self.logical_shape = (0, 0) + self.num_tensors = 0 + self.quantizer = None + self.quantized_tensors = None + self.offsets = None + self.scale_inv_offsets = None + self.columnwise_scale_inv_offsets = None + + def __repr__(self) -> str: + """String representation of the GroupedTensor.""" + return ( + f"GroupedTensor(num_tensors={self.num_tensors}, " + f"shape={self.shape}, " + f"logical_shape={self.logical_shape}, " + f"dtype={self.get_dtype()})" + ) + + def __str__(self) -> str: + """User-friendly string representation.""" + shape_info = [] + if self.all_same_shape(): + shape_info.append("uniform shape") + else: + if not self.all_same_first_dim(): + shape_info.append("varying first dim") + if not self.all_same_last_dim(): + shape_info.append("varying last dim") + + return ( + f"GroupedTensor with {self.num_tensors} tensors " + f"({', '.join(shape_info) if shape_info else 'uniform'}), " + f"logical_shape={self.logical_shape}, " + f"dtype={self.get_dtype()}" + ) + + @staticmethod + def make_grouped_tensor_with_shapes( + num_tensors: int, + shape: List[Tuple[int, int]], + quantizer: Optional[Quantizer] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> GroupedTensor: + """ + Create a GroupedTensor for storing multiple weight tensors of the same shape. + + Args: + num_tensors: Number of tensors + shape: 2D shape of each tensor (len num_tensors) + quantizer: Quantizer for each tensor + device: Device to allocate tensors on, defaults to current cuda device + dtype: Data type of the tensor (for high precision case) + + Returns: + A GroupedTensor. + """ + + # First dim + first_dim_list = [s[0] for s in shape] + uniform_first_dim = all(first_dim_list[0] == x for x in first_dim_list) + logical_first_dim = sum(first_dim_list) + if uniform_first_dim: + first_dims = None + else: + first_dims = torch.tensor([s[0] for s in shape], dtype=torch.int64, device=device) + + # Last dim + last_dim_list = [s[1] for s in shape] + logical_last_dim = last_dim_list[0] + assert all(logical_last_dim == x for x in last_dim_list), "Last dims should be uniform" + + return GroupedTensor.make_grouped_tensor( + num_tensors=num_tensors, + first_dims=first_dims, + last_dims=None, + logical_first_dim=logical_first_dim, + logical_last_dim=logical_last_dim, + quantizer=quantizer, + device=device, + dtype=dtype, + ) + + @staticmethod + def make_grouped_tensor( + num_tensors: int, + first_dims: Optional[torch.Tensor], + last_dims: Optional[torch.tensor], + logical_first_dim: int, + logical_last_dim: int, + quantizer: Optional[Quantizer] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> GroupedTensor: + """ + Create a GroupedTensor for storing multiple weight tensors of the same shape. + + Args: + num_tensors: Number of tensors + first_dims: Device tensor of int64 array of length num_tensors (or None if uniform) + last_dims: Device tensor of int64 array of length num_tensors (or None if uniform) + logical_first_dim: Logical first dimension + logical_last_dim: Logical last dimension + quantizer: Quantizer for each tensor + Used to figure out the recipe and what to allocate. 
+ device: Device to allocate tensors on, defaults to current cuda device + dtype: Data type of the tensor (for high precision case) + + Returns: + A GroupedTensor. + """ + + # Set device + if device is None: + device = torch.cuda.current_device() + + # Shape patterns and validation. + all_same_first = first_dims is None + all_same_last = last_dims is None + + assert all_same_last, "Last dim must be uniform for GroupedTensor" + assert logical_first_dim > 0, "Logical first dim must be positive for GroupedTensor" + assert logical_last_dim > 0, "Logical last dim must be positive for GroupedTensor" + + # assert ( + # logical_first_dim % 128 == 0 + # ), "Logical first dim must be divisible by 128" + # assert logical_last_dim % 128 == 0, "Logical last dim must be divisible by 128" + + # Calculate tensor offsets (cumulative element offsets) + tensor_offsets = None + offsets = None + shape = [] + if not all_same_first: + # Need explicit offsets for non-uniform shapes + # Offsets are based on number of elements and not pointers. + # Kernels need to calculate precise pointers based on size of elements. + + # TODO(ksivaman): Single kernel + remove the host offset calculation. + tensor_offsets = torch.cat( + [ + torch.zeros(1, device=first_dims.device, dtype=first_dims.dtype), + torch.cumsum(first_dims * logical_last_dim, dim=0), + ] + ) + offsets = tensor_offsets.tolist() + first_dims_list = first_dims.tolist() + for i in range(num_tensors): + shape.append((first_dims_list[i], logical_last_dim)) + else: + offsets = [ + i * logical_first_dim * logical_last_dim // num_tensors + for i in range(num_tensors + 1) + ] + for i in range(num_tensors): + shape.append((logical_first_dim // num_tensors, logical_last_dim)) + + # Calculate logical shape based + logical_shape = (logical_first_dim, logical_last_dim) + + no_quantization = quantizer is None + + rowwise_usage = quantizer.rowwise_usage if not no_quantization else True + columnwise_usage = quantizer.columnwise_usage if not no_quantization else False + + # Calculate total elements across all tensors + total_elements = logical_first_dim * logical_last_dim + + data = None + columnwise_data = None + scale_inv = None + columnwise_scale_inv = None + amax = None + columnwise_amax = None + scale = None + scale_inv_offsets = None + columnwise_scale_inv_offsets = None + if no_quantization: + assert dtype is not None, "dtype must be provided for unquantized GroupedTensor" + if rowwise_usage: + # Allocate rowwise data buffer (1D flattened, uint8) + data = torch.empty(total_elements, dtype=dtype, device=device) + + if columnwise_usage: + # Allocate columnwise data buffer (1D flattened, uint8) + columnwise_data = torch.empty(total_elements, dtype=dtype, device=device) + elif quantizer._get_compatible_recipe().mxfp8(): + if rowwise_usage: + # Allocate rowwise data buffer (1D flattened, uint8) + data = torch.empty(total_elements, dtype=torch.uint8, device=device) + # Scale inverse buffer for MXFP8 - complex shape based on block scaling + # For grouped tensors, we need to calculate scale_inv size for all tensors + total_scale_elements = 0 + scale_inv_offsets = [0] + for i, s in enumerate(shape): + scale_inv_shape = quantizer.get_scale_shape(s, False) + scale_elements = math.prod(scale_inv_shape) + total_scale_elements += scale_elements + if i < num_tensors - 1: + scale_inv_offsets.append(total_scale_elements) + scale_inv = torch.empty(total_scale_elements, dtype=torch.uint8, device=device) + + if columnwise_usage: + # Allocate columnwise data buffer (1D flattened, uint8) + 
columnwise_data = torch.empty(total_elements, dtype=torch.uint8, device=device) + # Columnwise scale inverse buffer + total_columnwise_scale_elements = 0 + columnwise_scale_inv_offsets = [0] + for i, s in enumerate(shape): + scale_inv_shape = quantizer.get_scale_shape(s, False) + columnwise_scale_elements = math.prod(scale_inv_shape) + total_columnwise_scale_elements += columnwise_scale_elements + if i < num_tensors - 1: + columnwise_scale_inv_offsets.append(total_columnwise_scale_elements) + columnwise_scale_inv = torch.empty( + total_columnwise_scale_elements, dtype=torch.uint8, device=device + ) + elif quantizer._get_compatible_recipe().delayed(): + if rowwise_usage: + # Allocate rowwise data buffer (1D flattened, uint8) + data = torch.empty(total_elements, dtype=torch.uint8, device=device) + # Scale inverse - one per tensor + scale_inv = torch.empty(num_tensors, dtype=torch.float32, device=device) + # One scale per tensor, so offsets are simply 0, 1, 2, ..., num_tensors-1 + scale_inv_offsets = list(range(num_tensors)) + + if columnwise_usage: + # Allocate columnwise data buffer (1D flattened, uint8) + columnwise_data = torch.empty(total_elements, dtype=torch.uint8, device=device) + # Columnwise scale inverse - one per tensor + columnwise_scale_inv = torch.empty(num_tensors, dtype=torch.float32, device=device) + # One scale per tensor, so offsets are simply 0, 1, 2, ..., num_tensors-1 + columnwise_scale_inv_offsets = list(range(num_tensors)) + + # Amax buffer for delayed scaling - one per tensor + amax = torch.empty(num_tensors, dtype=torch.float32, device=device) + elif quantizer._get_compatible_recipe().nvfp4(): + + if rowwise_usage: + # Allocate rowwise data buffer (1D flattened, uint8, but FP4 packs 2 values per byte) + data = torch.empty((total_elements) // 2, dtype=torch.uint8, device=device) + # Scale inverse buffer for NVFP4 - complex shape based on block scaling + # For simplicity, calculate total scale elements needed + total_scale_elements = 0 + scale_inv_offsets = [0] + for i, s in enumerate(shape): + scale_inv_shape = quantizer.get_scale_shape(s, False) + total_scale_elements += math.prod(scale_inv_shape) + if i < num_tensors - 1: + scale_inv_offsets.append(total_scale_elements) + scale_inv = torch.empty(total_scale_elements, dtype=torch.uint8, device=device) + # Amax buffer - one per tensor + amax = torch.empty(num_tensors, dtype=torch.float32, device=device) + + if columnwise_usage: + # Allocate columnwise data buffer (1D flattened, uint8, FP4 packed) + columnwise_data = torch.empty( + (total_elements) // 2, dtype=torch.uint8, device=device + ) + # Columnwise scale inverse buffer + total_columnwise_scale_elements = 0 + columnwise_scale_inv_offsets = [0] + for i, s in enumerate(shape): + columnwise_scale_inv_shape = quantizer.get_scale_shape(s, True) + total_columnwise_scale_elements += math.prod(columnwise_scale_inv_shape) + if i < num_tensors - 1: + columnwise_scale_inv_offsets.append(total_columnwise_scale_elements) + columnwise_scale_inv = torch.empty( + total_columnwise_scale_elements, dtype=torch.uint8, device=device + ) + # Columnwise amax buffer - one per tensor + columnwise_amax = torch.empty(num_tensors, dtype=torch.float32, device=device) + elif quantizer._get_compatible_recipe().float8_block_scaling(): + if rowwise_usage: + # Allocate rowwise data buffer (1D flattened, uint8) + data = torch.empty(total_elements, dtype=torch.uint8, device=device) + # Scale inverse - size depends on block configuration + # For simplicity, calculate total scale elements needed + 
total_scale_elements = 0 + scale_inv_offsets = [0] + for i, s in enumerate(shape): + scale_inv_shape = quantizer.get_scale_shape(s, False) + total_scale_elements += math.prod(scale_inv_shape) + if i < num_tensors - 1: + scale_inv_offsets.append(total_scale_elements) + scale_inv = torch.empty(total_scale_elements, dtype=torch.float32, device=device) + + if columnwise_usage: + # Allocate columnwise data buffer (1D flattened, uint8) + columnwise_data = torch.empty(total_elements, dtype=torch.uint8, device=device) + # Columnwise scale inverse + total_columnwise_scale_elements = 0 + columnwise_scale_inv_offsets = [0] + for i, s in enumerate(shape): + columnwise_scale_inv_shape = quantizer.get_scale_shape(s, True) + total_columnwise_scale_elements += math.prod(columnwise_scale_inv_shape) + if i < num_tensors - 1: + columnwise_scale_inv_offsets.append(total_columnwise_scale_elements) + columnwise_scale_inv = torch.empty( + total_columnwise_scale_elements, dtype=torch.float32, device=device + ) + elif quantizer._get_compatible_recipe().float8_current_scaling(): + # Current scaling - per-tensor scaling computed on the fly + if rowwise_usage: + # Allocate rowwise data buffer (1D flattened, uint8) + data = torch.empty(total_elements, dtype=torch.uint8, device=device) + # Scale inverse - one per tensor + scale_inv = torch.empty(num_tensors, dtype=torch.float32, device=device) + # One scale per tensor, so offsets are simply 0, 1, 2, ..., num_tensors-1 + scale_inv_offsets = list(range(num_tensors)) + + if columnwise_usage: + # Allocate columnwise data buffer (1D flattened, uint8) + columnwise_data = torch.empty(total_elements, dtype=torch.uint8, device=device) + # Columnwise scale inverse - one per tensor + columnwise_scale_inv = torch.empty(num_tensors, dtype=torch.float32, device=device) + # One scale per tensor, so offsets are simply 0, 1, 2, ..., num_tensors-1 + columnwise_scale_inv_offsets = list(range(num_tensors)) + + # Scale and amax buffers for current scaling - one per tensor + scale = torch.empty(num_tensors, dtype=torch.float32, device=device) + amax = torch.empty(num_tensors, dtype=torch.float32, device=device) + else: + raise ValueError(f"Unsupported quantizer for GroupedTensor: {quantizer}") + + grouped_tensor = GroupedTensor( + num_tensors=num_tensors, + shape=shape, + dtype=dtype, + quantizer=quantizer, + data=data, + columnwise_data=columnwise_data, + scale_inv=scale_inv, + columnwise_scale_inv=columnwise_scale_inv, + amax=amax, + columnwise_amax=columnwise_amax, + scale=scale, + first_dims=first_dims, + last_dims=last_dims, + tensor_offsets=tensor_offsets, + offsets=offsets, + scale_inv_offsets=scale_inv_offsets, + columnwise_scale_inv_offsets=columnwise_scale_inv_offsets, + logical_shape=logical_shape, + ) + + grouped_tensor.quantized_tensors = grouped_tensor.split_into_quantized_tensors() + return grouped_tensor + + def split_into_quantized_tensors( + self, + ) -> List[Union[QuantizedTensorStorage, torch.Tensor]]: + """ + Split the GroupedTensor into a list of `num_tensors` + quantized tensors based on the quantizer. No additional memory allocation is performed, + so the tensors returned are the same as the ones used to create the GroupedTensor. + + If quantizer is None, returns normal torch tensors. + If quantizer.internal is True, returns QuantizedTensorStorage. + Otherwise, returns QuantizedTensor. + + TODO(ksivaman): Block cases where any dims are varying. This is needed only + to expose the weights as separate parameters. 
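+
+        Example (uniform (M, N) shapes, no quantizer): tensor i is returned as
+        data[i * M * N : (i + 1) * M * N].view(M, N), a zero-copy view into the shared buffer.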
+ """ + + result = [] + + no_quantization = self.quantizer is None + + # Case 1: No quantization - return regular torch tensors + if no_quantization: + for i in range(self.num_tensors): + # Get tensor shape + tensor_shape = self.shape[i] + + # Get tensor data slice + if self.offsets is not None: + start_offset = self.offsets[i] + numel = tensor_shape[0] * tensor_shape[1] + end_offset = start_offset + numel + + if self.has_data(): + tensor_data = self.data[start_offset:end_offset].view(tensor_shape) + result.append(tensor_data) + elif self.has_columnwise_data(): + tensor_data = self.columnwise_data[start_offset:end_offset].view( + tensor_shape + ) + result.append(tensor_data) + else: + raise RuntimeError("GroupedTensor has no data to split") + else: + # All same shape case + numel = tensor_shape[0] * tensor_shape[1] + start_offset = i * numel + end_offset = start_offset + numel + + if self.has_data(): + tensor_data = self.data[start_offset:end_offset].view(tensor_shape) + result.append(tensor_data) + elif self.has_columnwise_data(): + tensor_data = self.columnwise_data[start_offset:end_offset].view( + tensor_shape + ) + result.append(tensor_data) + else: + raise RuntimeError("GroupedTensor has no data to split") + + return result + + # Case 2: Quantized tensors + recipe = self.quantizer._get_compatible_recipe() + + for i in range(self.num_tensors): + # Get tensor shape + tensor_shape = self.shape[i] + numel = tensor_shape[0] * tensor_shape[1] + + # Get data offsets + if self.offsets is not None: + data_start = self.offsets[i] + data_end = data_start + numel + else: + # All same shape + data_start = i * numel + data_end = data_start + numel + + # Special shape handling for NVFP4. + nvfp4 = self.quantizer._get_compatible_recipe().nvfp4() + if nvfp4: + data_start = data_start // 2 + data_end = data_end // 2 + + # Extract rowwise and columnwise data + rowwise_data = None + columnwise_data = None + + if self.has_data(): + if nvfp4: + rowwise_tensor_shape = self.quantizer.convert_shape_for_fp4(tensor_shape) + else: + rowwise_tensor_shape = tensor_shape + rowwise_data = self.data[data_start:data_end].view(rowwise_tensor_shape) + + if self.has_columnwise_data(): + columnwise_tensor_shape = self.quantizer.get_columnwise_shape(tensor_shape) + if nvfp4: + columnwise_tensor_shape = self.quantizer.convert_shape_for_fp4( + columnwise_tensor_shape + ) + columnwise_data = self.columnwise_data[data_start:data_end].view( + columnwise_tensor_shape + ) + + # MXFP8 format + if recipe.mxfp8(): + # Extract scale_inv data + rowwise_scale_inv = None + columnwise_scale_inv = None + + if self.scale_inv is not None and self.scale_inv_offsets is not None: + scale_start = self.scale_inv_offsets[i] + if i < self.num_tensors - 1: + scale_end = self.scale_inv_offsets[i + 1] + else: + scale_end = self.scale_inv.numel() + + # Calculate expected scale shape for MXFP8 + scale_shape = self.quantizer.get_scale_shape(tensor_shape, False) + rowwise_scale_inv = self.scale_inv[scale_start:scale_end].view(scale_shape) + + if ( + self.columnwise_scale_inv is not None + and self.columnwise_scale_inv_offsets is not None + ): + cscale_start = self.columnwise_scale_inv_offsets[i] + if i < self.num_tensors - 1: + cscale_end = self.columnwise_scale_inv_offsets[i + 1] + else: + cscale_end = self.columnwise_scale_inv.numel() + + cscale_shape = self.quantizer.get_scale_shape(tensor_shape, True) + columnwise_scale_inv = self.columnwise_scale_inv[cscale_start:cscale_end].view( + cscale_shape + ) + + if self.quantizer.internal: + 
mxfp8_tensor_class = MXFP8TensorStorage + else: + mxfp8_tensor_class = MXFP8Tensor + tensor = mxfp8_tensor_class( + shape=tensor_shape, + dtype=self.dtype, + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + fp8_dtype=self.quantizer.dtype, + quantizer=self.quantizer, + with_gemm_swizzled_scales=self.quantizer.optimize_for_gemm, + ) + result.append(tensor) + + # Delayed scaling or current scaling (both use Float8TensorStorage) + elif recipe.delayed() or recipe.float8_current_scaling(): + # Scale inverse - one per tensor + scale_inv = None + if self.scale_inv is not None: + scale_inv = self.scale_inv[i : i + 1] + + if self.quantizer.internal: + float8_tensor_class = Float8TensorStorage + else: + float8_tensor_class = Float8Tensor + + tensor = float8_tensor_class( + shape=tensor_shape, + dtype=self.dtype, + data=rowwise_data, + fp8_scale_inv=scale_inv, + fp8_dtype=self.quantizer.dtype, + quantizer=self.quantizer, + data_transpose=columnwise_data, + ) + result.append(tensor) + + # Float8 block scaling + elif recipe.float8_block_scaling(): + # Extract scale_inv data + rowwise_scale_inv = None + columnwise_scale_inv = None + + if self.scale_inv is not None and self.scale_inv_offsets is not None: + scale_start = self.scale_inv_offsets[i] + if i < self.num_tensors - 1: + scale_end = self.scale_inv_offsets[i + 1] + else: + scale_end = self.scale_inv.numel() + + # Get scale shape from quantizer + scale_shape = self.quantizer.get_scale_shape(tensor_shape, False) + rowwise_scale_inv = self.scale_inv[scale_start:scale_end].view(scale_shape) + + if ( + self.columnwise_scale_inv is not None + and self.columnwise_scale_inv_offsets is not None + ): + cscale_start = self.columnwise_scale_inv_offsets[i] + if i < self.num_tensors - 1: + cscale_end = self.columnwise_scale_inv_offsets[i + 1] + else: + cscale_end = self.columnwise_scale_inv.numel() + + # Get columnwise scale shape from quantizer + cscale_shape = self.quantizer.get_scale_shape(tensor_shape, True) + columnwise_scale_inv = self.columnwise_scale_inv[cscale_start:cscale_end].view( + cscale_shape + ) + + # Compute is_2D_scaled and data_format from quantizer attributes + is_2D_scaled = self.quantizer.block_scaling_dim == 2 + + if self.quantizer.internal: + float8_blockwise_q_tensor_class = Float8BlockwiseQTensorStorage + else: + float8_blockwise_q_tensor_class = Float8BlockwiseQTensor + + tensor = float8_blockwise_q_tensor_class( + shape=tensor_shape, + dtype=self.dtype, + rowwise_data=rowwise_data, + rowwise_scale_inv=rowwise_scale_inv, + columnwise_data=columnwise_data, + columnwise_scale_inv=columnwise_scale_inv, + fp8_dtype=self.quantizer.dtype, + quantizer=self.quantizer, + is_2D_scaled=is_2D_scaled, + ) + result.append(tensor) + + # NVFP4 format + elif recipe.nvfp4(): + # Extract scale_inv data + rowwise_scale_inv = None + columnwise_scale_inv = None + amax_rowwise = None + amax_columnwise = None + + if self.scale_inv is not None and self.scale_inv_offsets is not None: + scale_start = self.scale_inv_offsets[i] + if i < self.num_tensors - 1: + scale_end = self.scale_inv_offsets[i + 1] + else: + scale_end = self.scale_inv.numel() + + # Get scale shape from quantizer + scale_shape = self.quantizer.get_scale_shape(tensor_shape, False) + rowwise_scale_inv = self.scale_inv[scale_start:scale_end].view(scale_shape) + + if ( + self.columnwise_scale_inv is not None + and self.columnwise_scale_inv_offsets is not None + ): + cscale_start = 
self.columnwise_scale_inv_offsets[i]
+ if i < self.num_tensors - 1:
+ cscale_end = self.columnwise_scale_inv_offsets[i + 1]
+ else:
+ cscale_end = self.columnwise_scale_inv.numel()
+
+ # Get columnwise scale shape from quantizer
+ cscale_shape = self.quantizer.get_scale_shape(tensor_shape, True)
+ columnwise_scale_inv = self.columnwise_scale_inv[cscale_start:cscale_end].view(
+ cscale_shape
+ )
+
+ # Extract amax - one per tensor
+ if self.amax is not None:
+ amax_rowwise = self.amax[i : i + 1]
+
+ if self.columnwise_amax is not None:
+ amax_columnwise = self.columnwise_amax[i : i + 1]
+
+ if self.quantizer.internal:
+ nvfp4_tensor_class = NVFP4TensorStorage
+ else:
+ nvfp4_tensor_class = NVFP4Tensor
+
+ tensor = nvfp4_tensor_class(
+ shape=tensor_shape,
+ dtype=self.dtype,
+ rowwise_data=rowwise_data,
+ rowwise_scale_inv=rowwise_scale_inv,
+ columnwise_data=columnwise_data,
+ columnwise_scale_inv=columnwise_scale_inv,
+ amax_rowwise=amax_rowwise,
+ amax_columnwise=amax_columnwise,
+ fp4_dtype=self.quantizer.dtype,
+ quantizer=self.quantizer,
+ with_gemm_swizzled_scales=self.quantizer.optimize_for_gemm,
+ )
+ result.append(tensor)
+
+ else:
+ raise ValueError(f"Unsupported quantization recipe: {recipe}")
+
+ return result
+
+ @staticmethod
+ def create_and_quantize(
+ tensors: List[torch.Tensor],
+ quantizer: None | Quantizer,
+ *,
+ device: Optional[torch.device] = None,
+ dtype: Optional[torch.dtype] = None,
+ noop_flag: Optional[torch.Tensor] = None,
+ ) -> "GroupedTensor":
+ """
+ Quantize the given tensors into quantized tensors whose underlying
+ storage is allocated in a single GroupedTensor, and return that GroupedTensor.
+ """
+
+ grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
+ num_tensors=len(tensors),
+ shape=[t.shape for t in tensors],
+ quantizer=quantizer,
+ device=device,
+ dtype=dtype,
+ )
+
+ grouped_tensor.quantize(tensors, noop_flag=noop_flag)
+
+ return grouped_tensor
+
+ def quantize(
+ self,
+ tensors: List[torch.Tensor],
+ noop_flag: Optional[torch.Tensor] = None,
+ ) -> List[QuantizedTensorStorage]:
+ """
+ Quantize the GroupedTensor in place.
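+
+ `tensors` must contain `num_tensors` high-precision tensors whose shapes
+ match this GroupedTensor; each one is quantized into its slice of the shared
+ storage via the quantizer's `update_quantized`, and the per-tensor quantized
+ views are returned. A minimal sketch of the expected call pattern
+ (illustrative only; input shapes/dtype are example values):
+
+ inputs = [torch.randn(s, device="cuda", dtype=torch.bfloat16) for s in grouped.shape]
+ qviews = grouped.quantize(inputs)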
+ """ + + quantized_tensors = self.split_into_quantized_tensors() + for i in range(self.num_tensors): + self.quantizer.update_quantized(tensors[i], quantized_tensors[i], noop_flag=noop_flag) + return quantized_tensors From ffeace8707ee9c09bb6151923d51774480f68dcd Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Wed, 4 Feb 2026 17:57:15 +0000 Subject: [PATCH 2/5] grouped gemm support for bf16, bias support missing Signed-off-by: Varun Thumbe --- tests/cpp/operator/test_grouped_gemm.cu | 25 +- tests/pytorch/test_grouped_tensor.py | 56 --- tests/pytorch/test_numerics.py | 215 +++++++++- .../common/gemm/cublaslt_grouped_gemm.cu | 36 +- .../transformer_engine/transformer_engine.h | 209 ++++++++++ .../pytorch/cpp_extensions/gemm.py | 92 +++++ transformer_engine/pytorch/csrc/extensions.h | 5 + .../pytorch/csrc/extensions/gemm.cpp | 66 +++ .../pytorch/csrc/extensions/pybind.cpp | 3 + transformer_engine/pytorch/csrc/pybind.h | 2 + .../pytorch/csrc/type_converters.cpp | 117 +++++- .../pytorch/module/grouped_linear.py | 382 +++++++++++++----- 12 files changed, 1032 insertions(+), 176 deletions(-) diff --git a/tests/cpp/operator/test_grouped_gemm.cu b/tests/cpp/operator/test_grouped_gemm.cu index a694052b15..65d0248b0a 100644 --- a/tests/cpp/operator/test_grouped_gemm.cu +++ b/tests/cpp/operator/test_grouped_gemm.cu @@ -44,8 +44,8 @@ enum class ShapeCase { size_t grouped_setup_workspace_size(const size_t num_tensors) { const size_t ptr_bytes = num_tensors * sizeof(void*); const size_t int_bytes = num_tensors * sizeof(int); - // Layout: 6 pointer arrays (A, B, C, D, alpha, beta) + 6 int arrays (a_rows, a_cols, b_rows, b_cols, d_rows, d_cols) - size_t size = 6 * ptr_bytes + 6 * int_bytes; + // Layout: 8 pointer arrays (A, B, C, D, alpha, beta, a_scale, b_scale) + 6 int arrays + size_t size = 8 * ptr_bytes + 6 * int_bytes; const size_t alignment = 256; size = ((size + alignment - 1) / alignment) * alignment; return size; @@ -88,16 +88,16 @@ struct TestParams { std::vector> make_shapes(ShapeCase scase) { switch (scase) { case ShapeCase::kAllSame: - return {{64, 64, 32}, {64, 64, 32}, {64, 64, 32}}; + return {{128, 256, 384}, {128, 256, 384}, {128, 256, 384}}; case ShapeCase::kSameFirst: // Same M (first dim), varying N and K - return {{64, 80, 32}, {64, 96, 48}, {64, 112, 64}}; + return {{128, 256, 384}, {128, 384, 512}, {128, 512, 640}}; case ShapeCase::kSameLast: // Same N (last dim), varying M and K - return {{64, 80, 32}, {80, 80, 48}, {96, 80, 64}}; + return {{128, 256, 384}, {256, 256, 512}, {384, 256, 640}}; case ShapeCase::kAllDifferent: default: - return {{64, 96, 32}, {80, 112, 48}, {96, 128, 64}}; + return {{128, 256, 384}, {256, 384, 512}, {384, 512, 640}}; } } @@ -123,10 +123,11 @@ void run_grouped_gemm_case(const TestParams& params) { for (size_t i = 0; i < num_gemms; ++i) { const auto [M, N, K] = shapes[i]; - const std::vector a_shape = params.transa ? std::vector{M, K} - : std::vector{K, M}; - const std::vector b_shape = params.transb ? std::vector{K, N} - : std::vector{N, K}; + + const std::vector a_shape = params.transa ? std::vector{N, K} + : std::vector{K, N}; + const std::vector b_shape = params.transb ? 
std::vector{K, M} + : std::vector{M, K}; switch (params.input_case) { case InputCase::kFP8Current: { A_tensors.emplace_back(make_fp8_operand("A" + std::to_string(i), a_shape)); @@ -247,6 +248,8 @@ void run_grouped_gemm_case(const TestParams& params) { nullptr, // config (use defaults) 0); + NVTE_CHECK_CUDA(cudaDeviceSynchronize()); + // Compare results for (size_t i = 0; i < num_gemms; ++i) { Tensor grouped_split("grouped_D" + std::to_string(i), std::vector{static_cast(std::get<0>(shapes[i])), @@ -288,7 +291,7 @@ std::string MakeGroupedGemmTestName(const testing::TestParamInfo kTestParams = { - // Basic tests + // FP8 tests (each tensor has random mean/stddev -> different scales) {InputCase::kFP8Current, true, false, ShapeCase::kAllDifferent, false}, {InputCase::kFP8Current, false, true, ShapeCase::kAllDifferent, false}, {InputCase::kFP8Current, false, false, ShapeCase::kAllSame, false}, diff --git a/tests/pytorch/test_grouped_tensor.py b/tests/pytorch/test_grouped_tensor.py index c9f1c024c8..318009c669 100644 --- a/tests/pytorch/test_grouped_tensor.py +++ b/tests/pytorch/test_grouped_tensor.py @@ -361,62 +361,6 @@ def test_static_quantize_method(self, quantization: str) -> None: expected_offset = _rowwise_offset_bytes(i * numel, quantization) assert rowwise_data.data_ptr() == original_data_ptr + expected_offset - @pytest.mark.parametrize( - "shape", - [[(256, 512), (512, 512), (768, 512)], [(512, 512), (512, 512), (512, 512)]], - ) - @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) - def test_quantize_grouped_mxfp8(self, shape: List[Tuple[int, int]]) -> None: - """Test grouped quantization for MXFP8 against per-tensor quantization.""" - # Test wont pass until the grouped quantization PR from Oleg is merged. - num_tensors = 2 - shape = [(512, 1024) for _ in range(num_tensors)] - - # Create BF16 input tensors and pack into a grouped tensor - input_tensors = [torch.randn(s, dtype=torch.bfloat16, device="cuda") for s in shape] - quantized_tensors = [ - MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(tensor) for tensor in input_tensors - ] - grouped_input = GroupedTensor.make_grouped_tensor_with_shapes( - num_tensors=num_tensors, - shape=shape, - quantizer=None, - device="cuda", - dtype=torch.bfloat16, - ) - - offset = 0 - for tensor in input_tensors: - numel = tensor.numel() - grouped_input.data[offset : offset + numel].copy_(tensor.reshape(-1)) - offset += numel - - # Create MXFP8 output grouped tensor (rowwise only for easier validation) - quantizers = [MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) for _ in range(num_tensors)] - - grouped_output = GroupedTensor.make_grouped_tensor_with_shapes( - num_tensors=num_tensors, - shape=shape, - quantizer=quantizers, - device="cuda", - ) - - # Quantize using grouped API (handle both 2-arg and 3-arg bindings) - _ = tex.quantize_grouped(grouped_input, grouped_output) - # Build expected output by quantizing each tensor independently - expected_data = [] - expected_scale_inv = [] - for tensor, quantizer in zip(input_tensors, quantizers): - qtensor = quantizer(tensor) - expected_data.append(qtensor._rowwise_data.reshape(-1)) - expected_scale_inv.append(qtensor._rowwise_scale_inv.reshape(-1)) - - expected_data = torch.cat(expected_data) - expected_scale_inv = torch.cat(expected_scale_inv) - - assert torch.equal(grouped_output.data, expected_data) - assert torch.equal(grouped_output.scale_inv, expected_scale_inv) - def test_clear(self) -> None: """Test clear method""" num_tensors = 3 diff --git a/tests/pytorch/test_numerics.py 
b/tests/pytorch/test_numerics.py index abe2806e66..401dc37a94 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -46,7 +46,12 @@ is_nvfp4_available, ) from transformer_engine.pytorch import checkpoint as te_checkpoint -from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm +from transformer_engine.pytorch.cpp_extensions import ( + general_gemm, + general_grouped_gemm, + general_grouped_gemm_for_grouped_tensor, +) +from transformer_engine.pytorch.tensor.storage.grouped_tensor import GroupedTensor from transformer_engine.common import recipe import transformer_engine_torch as tex from utils import ModelConfig, reset_rng_states @@ -1991,6 +1996,84 @@ def test_grouped_linear_accuracy( torch.testing.assert_close(o, o_ref, rtol=0, atol=0) +@pytest.mark.parametrize("single_weight", [True, False], ids=["single_weight", "multi_weight"]) +def test_grouped_linear_m_splits_tensor(single_weight): + """Test GroupedLinear with m_splits as torch tensor (no_quantization/bf16). + grouped_tensor_path is chosen and must match reference (single_weight vs reference model, + or multi_weight list m_splits vs tensor m_splits). + """ + if tex.get_cublasLt_version() < 130200: + pytest.skip("Grouped GEMM requires cuBLAS 13.2+.") + if torch.cuda.get_device_capability() < (10, 0): + pytest.skip("Grouped GEMM requires Blackwell (SM100) or newer.") + if not is_bf16_available(): + pytest.skip("bfloat16 is required for grouped GEMM test.") + + torch.manual_seed(0) + num_gemms = 3 + in_features = 32 + out_features = 64 + m_splits = torch.tensor([5, 7, 9], device="cuda", dtype=torch.int64) + m_splits_list = [5, 7, 9] + dtype = torch.bfloat16 + m_total = int(m_splits.sum().item()) + + reference = GroupedLinear( + num_gemms, + in_features, + out_features, + bias=False, + params_dtype=dtype, + device="cuda", + single_weight=False, + ).eval() + with torch.no_grad(): + ref_weights = [getattr(reference, f"weight{i}") for i in range(num_gemms)] + + model_under_test = GroupedLinear( + num_gemms, + in_features, + out_features, + bias=False, + params_dtype=dtype, + device="cuda", + single_weight=single_weight, + ).eval() + with torch.no_grad(): + if single_weight: + for i, w in enumerate( + model_under_test.grouped_weight_storage.split_into_quantized_tensors() + ): + w.copy_(ref_weights[i]) + else: + for i in range(num_gemms): + getattr(model_under_test, f"weight{i}").copy_(ref_weights[i]) + + inp = torch.randn(m_total, in_features, device="cuda", dtype=dtype, requires_grad=True) + inp_ref = inp.detach().clone().requires_grad_() + + if single_weight: + out = model_under_test(inp, m_splits) + out_ref = reference(inp_ref, m_splits) + else: + out = model_under_test(inp, m_splits) + out_ref = model_under_test(inp_ref, m_splits_list) + + torch.testing.assert_close(out, out_ref, **dtype_tols(dtype)) + + out.sum().backward() + out_ref.sum().backward() + + torch.testing.assert_close(inp.grad, inp_ref.grad, **dtype_tols(dtype)) + if single_weight: + ref_wgrad = torch.cat( + [getattr(reference, f"weight{i}").grad.view(-1) for i in range(num_gemms)] + ) + torch.testing.assert_close( + getattr(model_under_test, "weight0").grad, ref_wgrad, **dtype_tols(dtype) + ) + + @pytest.mark.skipif( torch.cuda.get_device_capability() != (9, 0), reason="Only enable CUTLASS grouped gemm on Hopper", @@ -2719,10 +2802,15 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass): torch.manual_seed(0) z, m, k, n = shape - dist = torch.sort(torch.randint(0, m, (z - 1,))).values.tolist() 
- m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist) - assert m_splits.sum() == m and len(m_splits) == z - m_splits = m_splits.tolist() + if z == 1: + m_splits = [m] + else: + split_points = torch.randperm(m - 1)[: z - 1] + 1 + split_points = torch.sort(split_points).values.tolist() + m_splits = [split_points[0]] + m_splits += [b - a for a, b in zip(split_points[:-1], split_points[1:])] + m_splits.append(m - split_points[-1]) + assert sum(m_splits) == m and len(m_splits) == z if layout == "TN": A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight @@ -2790,6 +2878,123 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass): os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None) +def _pack_grouped_tensor(grouped_tensor: GroupedTensor, tensors: List[torch.Tensor]) -> None: + offset = 0 + for tensor in tensors: + numel = tensor.numel() + grouped_tensor.data[offset : offset + numel].copy_(tensor.reshape(-1)) + offset += numel + + + +@pytest.mark.parametrize("layout", ["TN", "NN", "NT"]) +@pytest.mark.parametrize("accumulate", [False]) +def test_grouped_gemm_grouped_tensor(layout, accumulate): + if tex.get_cublasLt_version() < 130200: + pytest.skip("Grouped GEMM requires cuBLAS 13.2+.") + if torch.cuda.get_device_capability() < (10, 0): + pytest.skip("Grouped GEMM requires Blackwell (SM100) or newer.") + if not is_bf16_available(): + pytest.skip("bfloat16 is required for grouped GEMM test.") + + torch.manual_seed(0) + z, m, k, n = (4, 512, 256, 256) + + split_points = torch.randperm(m - 1)[: z - 1] + 1 + split_points = torch.sort(split_points).values.tolist() + m_sizes = [split_points[0]] + m_sizes += [b - a for a, b in zip(split_points[:-1], split_points[1:])] + m_sizes.append(m - split_points[-1]) + assert sum(m_sizes) == m and len(m_sizes) == z + + dtype = torch.bfloat16 + + if layout == "TN": + A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight + B = [torch.randn(ms, k, dtype=dtype, device="cuda") for ms in m_sizes] # input + out = [torch.randn(ms, n, dtype=dtype, device="cuda") for ms in m_sizes] # output + grad = False + + elif layout == "NN": + A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight + B = [torch.randn(ms, n, dtype=dtype, device="cuda") for ms in m_sizes] # grad_output + out = [torch.randn(ms, k, dtype=dtype, device="cuda") for ms in m_sizes] # dgrad + grad = True + else: # layout == "NT" + A = [torch.randn(ms, k, dtype=dtype, device="cuda") for ms in m_sizes] # input + B = [torch.randn(ms, n, dtype=dtype, device="cuda") for ms in m_sizes] # grad_output + out = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # wgrad + grad = True + + out_ref = [o.clone() for o in out] + general_grouped_gemm( + A, + B, + out_ref, + [None] * z, + dtype, + m_splits=m_sizes, + grad=grad, + accumulate=accumulate, + layout=layout, + single_output=False, + ) + + device = A[0].device + def _make_grouped_tensor_from_splits(m_sizes, last_dim): + first_dims = torch.tensor(m_sizes, device=device, dtype=torch.int64) + return GroupedTensor.make_grouped_tensor( + num_tensors=len(m_sizes), + first_dims=first_dims, + last_dims=None, + logical_first_dim=sum(m_sizes), + logical_last_dim=last_dim, + quantizer=None, + device=device, + dtype=dtype, + ) + + def _make_grouped_tensor_uniform(num_tensors, first_dim, last_dim): + return GroupedTensor.make_grouped_tensor( + num_tensors=num_tensors, + first_dims=None, + last_dims=None, + logical_first_dim=num_tensors * first_dim, + 
logical_last_dim=last_dim, + quantizer=None, + device=device, + dtype=dtype, + ) + + if layout == "TN": + grouped_A = _make_grouped_tensor_uniform(z, n, k) + grouped_B = _make_grouped_tensor_from_splits(m_sizes, k) + grouped_out = _make_grouped_tensor_from_splits(m_sizes, n) + elif layout == "NN": + grouped_A = _make_grouped_tensor_uniform(z, n, k) + grouped_B = _make_grouped_tensor_from_splits(m_sizes, n) + grouped_out = _make_grouped_tensor_from_splits(m_sizes, k) + else: # layout == "NT" + grouped_A = _make_grouped_tensor_from_splits(m_sizes, k) + grouped_B = _make_grouped_tensor_from_splits(m_sizes, n) + grouped_out = _make_grouped_tensor_uniform(z, n, k) + _pack_grouped_tensor(grouped_A, A) + _pack_grouped_tensor(grouped_B, B) + _pack_grouped_tensor(grouped_out, out) + + general_grouped_gemm_for_grouped_tensor( + grouped_A, + grouped_B, + grouped_out, + layout=layout, + accumulate=accumulate, + ) + + out_grouped = grouped_out.split_into_quantized_tensors() + tols = dtype_tols(dtype) + for o, o_ref in zip(out_grouped, out_ref): + torch.testing.assert_close(o, o_ref, **tols) + @pytest.mark.parametrize("N", [32]) @pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize( diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index b3e216dc4f..5f33ab2733 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -11,6 +11,7 @@ #include #include +#include #include "../common.h" #include "../util/cuda_runtime.h" @@ -118,28 +119,42 @@ struct GroupedGemmSetupWorkspace { int *d_cols; // N (last dim) - also used for C // Initialize from workspace buffer - // Layout: all pointer arrays first (8-byte aligned), then int arrays (4-byte aligned) + // Layout: all pointer arrays first (16-byte aligned for cuBLAS), then int arrays static GroupedGemmSetupWorkspace from_buffers(char *setup_ws_ptr, size_t num_tensors) { GroupedGemmSetupWorkspace ws; size_t offset = 0; const size_t ptr_size = num_tensors * sizeof(void *); const size_t int_size = num_tensors * sizeof(int); - // Pointer arrays first (all 8-byte aligned) + constexpr size_t kPtrAlignment = 16; // cuBLAS requires 16-byte alignment for pointer arrays + + // Helper to align offset to kPtrAlignment + auto align_offset = [&]() { + offset = (offset + kPtrAlignment - 1) / kPtrAlignment * kPtrAlignment; + }; + + // Pointer arrays first (all 16-byte aligned for cuBLAS grouped GEMM) + align_offset(); ws.A_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; + align_offset(); ws.B_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; + align_offset(); ws.C_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; + align_offset(); ws.D_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; + align_offset(); ws.alpha_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; + align_offset(); ws.beta_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; // Int arrays for storage dimensions (4-byte aligned) + align_offset(); ws.a_rows = reinterpret_cast(setup_ws_ptr + offset); offset += int_size; ws.a_cols = reinterpret_cast(setup_ws_ptr + offset); @@ -159,8 +174,11 @@ struct GroupedGemmSetupWorkspace { static size_t required_setup_size(size_t num_tensors, size_t alignment) { const size_t ptr_size = num_tensors * sizeof(void *); const size_t int_size = num_tensors * sizeof(int); - // Layout: 6 ptr 
arrays, then 6 int arrays - size_t size = 6 * ptr_size + 6 * int_size; + constexpr size_t kPtrAlignment = 16; // Must match from_buffers + // Layout: 8 ptr arrays (each 16-byte aligned), then 6 int arrays + // Each ptr array takes ptr_size bytes but needs to start at 16-byte boundary + auto aligned_ptr_size = ((ptr_size + kPtrAlignment - 1) / kPtrAlignment) * kPtrAlignment; + size_t size = 8 * aligned_ptr_size + 6 * int_size; size = ((size + alignment - 1) / alignment) * alignment; return size; } @@ -383,6 +401,10 @@ inline void init_matmul_desc(cublasLtMatmulDescOpaque_t &matmulDesc, cublasOpera NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_BETA_BATCH_STRIDE, &alphabeta_batch_stride, sizeof(int64_t))); + // Fast accumulation mode: 0 = split accumulator (more accurate), 1 = fast accumulator + int8_t fastAccuMode = 0; // Use split accumulator for accuracy + NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, + &fastAccuMode, sizeof(fastAccuMode))); } inline void set_fp8_scale_pointers(cublasLtMatmulDescOpaque_t &matmulDesc, @@ -487,9 +509,9 @@ __global__ void setup_grouped_gemm_kernel( a_cols[idx] = static_cast(a_first); b_rows[idx] = static_cast(b_last); b_cols[idx] = static_cast(b_first); - // For OUTPUTS (D, C): cuBLAS writes in column-major, so rows=first (M), cols=last (N). - d_rows[idx] = static_cast(d_first); - d_cols[idx] = static_cast(d_last); + + d_rows[idx] = static_cast(d_last); + d_cols[idx] = static_cast(d_first); // Fill alpha/beta pointers (per-matrix) alpha_ptrs[idx] = alpha_ptr + idx; diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index ae41f238a4..1d4186489b 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -957,6 +957,215 @@ class TensorWrapper { NVTETensor tensor_ = nullptr; }; + +/*! \struct GroupedTensorWrapper + * \brief C++ wrapper for the NVTEGroupedTensor class. + */ + + class GroupedTensorWrapper { + public: + /*! \brief Constructs new GroupedTensorWrapper. + * + * Create a new TE grouped tensor with a given logical shape. + * TE grouped tensors are just wrappers on top of raw data and do not + * own memory. + * + * \param[in] num_tensors Number of tensors in the group (must be > 0). + * \param[in] logical_shape Logical 2D shape of the grouped data. + * \param[in] scaling_mode Tensor data format. + */ + GroupedTensorWrapper(const size_t num_tensors, const NVTEShape &logical_shape, + const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) + : tensor_(nvte_create_grouped_tensor(scaling_mode, num_tensors, logical_shape)) {} + + /*! \brief Constructs new GroupedTensorWrapper. + * + * Create a new TE grouped tensor with a given logical shape. + * + * \param[in] num_tensors Number of tensors in the group (must be > 0). + * \param[in] logical_shape Logical 2D shape of the grouped data. + * \param[in] scaling_mode Tensor data format. + */ + GroupedTensorWrapper(const size_t num_tensors, const std::vector &logical_shape, + const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) + : GroupedTensorWrapper(num_tensors, + nvte_make_shape(logical_shape.data(), logical_shape.size()), + scaling_mode) {} + + /*! \brief GroupedTensorWrapper destructor. 
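Releases the wrapped NVTEGroupedTensor handle via
+ nvte_destroy_grouped_tensor; the underlying data buffers are not owned by the
+ wrapper and are left untouched.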
*/ + ~GroupedTensorWrapper() { nvte_destroy_grouped_tensor(tensor_); } + + GroupedTensorWrapper &operator=(const GroupedTensorWrapper &other) = delete; + GroupedTensorWrapper(const GroupedTensorWrapper &other) = delete; + + /*! \brief Constructs new GroupedTensorWrapper from existing GroupedTensorWrapper. */ + GroupedTensorWrapper(GroupedTensorWrapper &&other) { + tensor_ = other.tensor_; + other.tensor_ = nullptr; + } + + /*! \brief Assign the data from existing GroupedTensorWrapper. */ + GroupedTensorWrapper &operator=(GroupedTensorWrapper &&other) { + if (this == &other) return *this; + nvte_destroy_grouped_tensor(tensor_); + tensor_ = other.tensor_; + other.tensor_ = nullptr; + return *this; + } + + // Parameter setters + template + GroupedTensorWrapper &set_parameter(const NVTEGroupedTensorParam param, void *dptr, DType type, + const ShapeType &shape) noexcept { + NVTEShape nvte_shape = this->convertShape(shape); + NVTEBasicTensor data = {dptr, static_cast(type), nvte_shape}; + nvte_set_grouped_tensor_param(&tensor_, param, &data); + return *this; + } + + template + GroupedTensorWrapper &set_rowwise_data(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedRowwiseData, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_columnwise_data(void *dptr, DType type, + const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedColumnwiseData, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_scale(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedScale, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_amax(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedAmax, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_rowwise_scale_inv(void *dptr, DType type, + const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedRowwiseScaleInv, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_columnwise_scale_inv(void *dptr, DType type, + const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedColumnwiseScaleInv, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_columnwise_amax(void *dptr, DType type, + const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedColumnwiseAmax, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_first_dims(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedFirstDims, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_last_dims(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedLastDims, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_tensor_offsets(void *dptr, DType type, + const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedTensorOffsets, dptr, type, shape); + } + + // Parameter getters + NVTEBasicTensor get_parameter(const NVTEGroupedTensorParam param) const noexcept { + return nvte_get_grouped_tensor_param(tensor_, param); + } + + NVTEBasicTensor get_rowwise_data() const noexcept { + return get_parameter(kNVTEGroupedRowwiseData); + } + + NVTEBasicTensor get_columnwise_data() const noexcept { + return get_parameter(kNVTEGroupedColumnwiseData); + } + + NVTEBasicTensor get_scale() const noexcept { return get_parameter(kNVTEGroupedScale); } + + NVTEBasicTensor get_amax() const noexcept { return get_parameter(kNVTEGroupedAmax); } + + NVTEBasicTensor get_rowwise_scale_inv() 
const noexcept { + return get_parameter(kNVTEGroupedRowwiseScaleInv); + } + + NVTEBasicTensor get_columnwise_scale_inv() const noexcept { + return get_parameter(kNVTEGroupedColumnwiseScaleInv); + } + + NVTEBasicTensor get_columnwise_amax() const noexcept { + return get_parameter(kNVTEGroupedColumnwiseAmax); + } + + NVTEBasicTensor get_first_dims() const noexcept { return get_parameter(kNVTEGroupedFirstDims); } + + NVTEBasicTensor get_last_dims() const noexcept { return get_parameter(kNVTEGroupedLastDims); } + + NVTEBasicTensor get_tensor_offsets() const noexcept { + return get_parameter(kNVTEGroupedTensorOffsets); + } + + /*! \brief Get an underlying NVTEGroupedTensor. + * + * \return NVTEGroupedTensor held by this GroupedTensorWrapper. + */ + NVTEGroupedTensor data() const noexcept { return tensor_; } + + /*! \brief Get the number of tensors in this GroupedTensorWrapper. */ + size_t num_tensors() const noexcept { + if (tensor_ == nullptr) return 0; + return nvte_grouped_tensor_num_tensors(tensor_); + } + + /*! \brief Get the data type of this GroupedTensorWrapper. */ + DType dtype() const noexcept { + if (tensor_ == nullptr) return DType::kNumTypes; + return static_cast(nvte_grouped_tensor_type(tensor_)); + } + + /*! \brief Get a scaling mode of the grouped tensor. */ + NVTEScalingMode scaling_mode() const noexcept { + if (tensor_ == nullptr) return NVTE_DELAYED_TENSOR_SCALING; + return nvte_grouped_tensor_scaling_mode(tensor_); + } + + /*! \brief Get the logical shape of this GroupedTensorWrapper. */ + const NVTEShape logical_shape() const noexcept { + if (tensor_ == nullptr) { + return emptyShape; + } + return nvte_get_grouped_tensor_logical_shape(tensor_); + } + + static constexpr size_t defaultData = 1; + static constexpr NVTEShape defaultShape = { + {defaultData, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1}; + static constexpr NVTEShape emptyShape = {{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1}; + + private: + NVTEShape convertShape(const NVTEShape &s) { return s; } + + NVTEShape convertShape(const std::vector &s) { + return nvte_make_shape(s.data(), s.size()); + } + + /*! \brief Wrapped NVTEGroupedTensor. */ + NVTEGroupedTensor tensor_ = nullptr; + }; + + /*! \enum Float8BlockScaleTensorFormat + * \brief Data format for an FP8 block-scaled tensor + */ /*! \warning Deprecated */ enum class Float8BlockScaleTensorFormat { GEMM_READY = 0, COMPACT = 1, INVALID }; diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 406e7075f7..631569ac91 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -5,6 +5,7 @@ """Python interface for GEMM extensions""" from typing import Iterable, Optional, Tuple, Union, List +import ctypes import os import functools import torch @@ -14,6 +15,7 @@ from ..quantized_tensor import Quantizer from ..tensor.storage.float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage +from ..tensor.storage.grouped_tensor import GroupedTensor from ..tensor.utils import is_custom from ..custom_recipes.gemm import custom_gemm from ...debug.pytorch.debug_quantization import DebugQuantizer @@ -22,6 +24,7 @@ __all__ = [ "general_gemm", "general_grouped_gemm", + "general_grouped_gemm_for_grouped_tensor", ] @@ -306,3 +309,92 @@ def general_grouped_gemm( ) return out, bias, gelu_input + +def get_grouped_gemm_setup_workspace_size(num_tensors: int) -> int: + """Return workspace size for grouped GEMM pointer setup. 
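+
+ The layout mirrors the C++ setup workspace: 8 pointer arrays (A, B, C, D,
+ alpha, beta, a_scale, b_scale), each padded up to a 16-byte boundary,
+ followed by 6 int arrays, with the total rounded up to a 256-byte multiple.
+ As an illustrative check (assuming 64-bit pointers), num_tensors=3 gives
+ 8 * 32 + 6 * 12 = 328 bytes, which rounds up to 512.
+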
+ Must match GroupedGemmSetupWorkspace::required_setup_size in cublaslt_grouped_gemm.cu. + """ + ptr_bytes = ctypes.sizeof(ctypes.c_void_p) + int_bytes = ctypes.sizeof(ctypes.c_int) + ptr_size = num_tensors * ptr_bytes + int_size = num_tensors * int_bytes + k_ptr_alignment = 16 + aligned_ptr_size = ((ptr_size + k_ptr_alignment - 1) // k_ptr_alignment) * k_ptr_alignment + size = 8 * aligned_ptr_size + 6 * int_size + alignment = 256 + return ((size + alignment - 1) // alignment) * alignment + + +def general_grouped_gemm_for_grouped_tensor( + A, + B, + out, + *, + layout: str = "TN", + accumulate: bool = False, + alpha: Optional[torch.Tensor] = None, + beta: Optional[torch.Tensor] = None, +) -> torch.Tensor: + """ + Grouped GEMM using GroupedTensor inputs. + + This uses nvte_grouped_gemm and supports different per-matrix shapes. + + The caller must ensure that GroupedTensor metadata is already compatible with the + underlying GEMM implementation (e.g., aligned offsets and output metadata layout). + """ + assert layout in ("TN", "NN", "NT"), f"GEMM layout {layout} not supported." + transa = layout[0] == "T" + transb = layout[1] == "T" + + num_tensors = A.num_tensors + assert A.num_tensors == B.num_tensors == out.num_tensors, ( + f"GroupedTensor num_tensors must match: A={A.num_tensors}, B={B.num_tensors}, out={out.num_tensors}" + ) + + if out.data is not None: + device = out.data.device + elif out.columnwise_data is not None: + device = out.columnwise_data.device + else: + raise ValueError("Output GroupedTensor must have allocated data.") + + if alpha is None: + alpha = torch.ones(num_tensors, dtype=torch.float32, device=device) + if beta is None: + if accumulate: + beta = torch.ones(num_tensors, dtype=torch.float32, device=device) + else: + beta = torch.zeros(num_tensors, dtype=torch.float32, device=device) + + if not alpha.is_cuda or not beta.is_cuda: + raise ValueError("alpha and beta must be CUDA tensors.") + + workspace_setup = torch.empty( + get_grouped_gemm_setup_workspace_size(num_tensors), + dtype=torch.uint8, + device=device, + ) + workspace_cublas = torch.empty( + get_cublas_workspace_size_bytes(), + dtype=torch.uint8, + device=device, + ) + + sm_count = get_sm_count() + sm_count = sm_count - int(os.getenv("NVTE_EXT_MARGIN_SM", str(sm_count))) + + C = out + return tex.te_general_grouped_gemm_for_grouped_tensor( + A, + transa, + B, + transb, + C, + out, + alpha, + beta, + workspace_setup, + workspace_cublas, + sm_count, + ) diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index f7cf32eaf6..fa7e27ed00 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -150,6 +150,11 @@ std::optional> te_general_grouped_gemm( std::vector pre_gelu_out, bool grad, std::vector workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count); +py::object te_general_grouped_gemm_for_grouped_tensor( + py::handle A, bool transa, py::handle B, bool transb, py::object C, py::handle D, + at::Tensor alpha, at::Tensor beta, at::Tensor workspace_setup, at::Tensor workspace_cublas, + int math_sm_count); + /*************************************************************************************************** * Transpose **************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index d75b0f14c7..765f346bb2 100644 --- 
a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -570,4 +570,70 @@ std::optional> te_general_grouped_gemm( return bias; } +py::object te_general_grouped_gemm_for_grouped_tensor( + py::handle A, bool transa, py::handle B, bool transb, py::object C, py::handle D, + at::Tensor alpha, at::Tensor beta, at::Tensor workspace_setup, at::Tensor workspace_cublas, + int math_sm_count) { + using namespace transformer_engine::pytorch::detail; + + init_extension(); + + // Ensure that cublasLt handle is created on the correct device, + // overriding torch.cuda.set_device calls from user side. + // Assumes all tensors passed are on the same device. + at::cuda::CUDAGuard device_guard(workspace_cublas.device()); + + auto grouped_A = GroupedTensorFromPyTorchGroupedTensor(A); + auto grouped_B = GroupedTensorFromPyTorchGroupedTensor(B); + auto grouped_D = GroupedTensorFromPyTorchGroupedTensor(D); + + std::optional grouped_C = std::nullopt; + if (!C.is_none()) { + grouped_C = GroupedTensorFromPyTorchGroupedTensor(C); + } + + const size_t num_tensors = grouped_A.num_tensors(); + NVTE_CHECK(num_tensors > 0, "Grouped GEMM requires non-empty inputs."); + NVTE_CHECK(grouped_B.num_tensors() == num_tensors, + "Grouped GEMM requires A and B to have the same num_tensors."); + NVTE_CHECK(grouped_D.num_tensors() == num_tensors, + "Grouped GEMM requires D to have the same num_tensors as inputs."); + if (grouped_C.has_value()) { + NVTE_CHECK(grouped_C->num_tensors() == num_tensors, + "Grouped GEMM requires C to have the same num_tensors as inputs."); + } + + NVTE_CHECK(alpha.numel() == static_cast(num_tensors), + "Grouped GEMM expects alpha to have num_tensors elements."); + NVTE_CHECK(beta.numel() == static_cast(num_tensors), + "Grouped GEMM expects beta to have num_tensors elements."); + + auto te_alpha = makeTransformerEngineTensor(alpha); + auto te_beta = makeTransformerEngineTensor(beta); + + auto te_workspace_setup = makeTransformerEngineTensor( + workspace_setup.data_ptr(), std::vector{static_cast(workspace_setup.numel())}, + DType::kByte); + auto te_workspace_cublas = makeTransformerEngineTensor( + workspace_cublas.data_ptr(), + std::vector{static_cast(workspace_cublas.numel())}, DType::kByte); + + std::optional config; + if (math_sm_count > 0) { + config.emplace(); + config->set_sm_count(math_sm_count); + } + + NVTE_SCOPED_GIL_RELEASE({ + nvte_grouped_gemm(grouped_A.data(), transa, grouped_B.data(), transb, + grouped_C.has_value() ? grouped_C->data() : nullptr, grouped_D.data(), + te_alpha.data(), te_beta.data(), te_workspace_setup.data(), + te_workspace_cublas.data(), + config.has_value() ? 
static_cast(*config) : nullptr, + at::cuda::getCurrentCUDAStream()); + }); + + return py::reinterpret_borrow(D); +} + } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 79dd9ea5ce..c5f78308ea 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -251,6 +251,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("quantizer_list"), py::arg("disable_bulk_allocation") = false); m.def("te_general_grouped_gemm", &transformer_engine::pytorch::te_general_grouped_gemm, "Grouped GEMM"); + m.def("te_general_grouped_gemm_for_grouped_tensor", + &transformer_engine::pytorch::te_general_grouped_gemm_for_grouped_tensor, + "Grouped GEMM for GroupedTensor"); m.def("fp8_transpose", &transformer_engine::pytorch::fp8_transpose, "Transpose with FP8 I/O", py::arg("input"), py::arg("dtype"), py::kw_only(), py::arg("out"), py::call_guard()); diff --git a/transformer_engine/pytorch/csrc/pybind.h b/transformer_engine/pytorch/csrc/pybind.h index 25ffef0588..9541409c0c 100644 --- a/transformer_engine/pytorch/csrc/pybind.h +++ b/transformer_engine/pytorch/csrc/pybind.h @@ -95,6 +95,8 @@ TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer); +GroupedTensorWrapper GroupedTensorFromPyTorchGroupedTensor(py::handle tensor); + inline bool IsFloatingPointType(at::ScalarType type) { return type == at::kFloat || type == at::kHalf || type == at::kBFloat16; } diff --git a/transformer_engine/pytorch/csrc/type_converters.cpp b/transformer_engine/pytorch/csrc/type_converters.cpp index 3f998bb66f..e6e392b894 100644 --- a/transformer_engine/pytorch/csrc/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/type_converters.cpp @@ -170,6 +170,121 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer) return ret; } +NVTEScalingMode ScalingModeFromQuantizer(py::handle quantizer) { + auto *quantizer_ptr = quantizer.ptr(); + if (IsMXFP8Quantizers(quantizer_ptr)) { + return NVTE_MXFP8_1D_SCALING; + } + if (IsNVFP4Quantizers(quantizer_ptr)) { + return NVTE_NVFP4_1D_SCALING; + } + if (IsFloat8BlockwiseQuantizers(quantizer_ptr)) { + const int block_scaling_dim = quantizer.attr("block_scaling_dim").cast(); + return (block_scaling_dim == 2) ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D; + } + return NVTE_DELAYED_TENSOR_SCALING; +} + +DType GetTransformerEngineDTypeForScaleInv(py::handle quantizer, at::Tensor scale_inv) { + auto *quantizer_ptr = quantizer.ptr(); + if (IsMXFP8Quantizers(quantizer_ptr)) { + return DType::kFloat8E8M0; + } + if (IsFloat8BlockwiseQuantizers(quantizer_ptr)) { + return DType::kFloat32; + } + if (IsNVFP4Quantizers(quantizer_ptr)) { + return DType::kFloat8E4M3; + } + return GetTransformerEngineDType(scale_inv.scalar_type()); +} + + + +GroupedTensorWrapper GroupedTensorFromPyTorchGroupedTensor(py::handle tensor) { + // Returns a GroupedTensorWrapper from a PyTorch GroupedTensor. 
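+ // The Python GroupedTensor is duck-typed: only the attributes read below are
+ // required (num_tensors, logical_shape, quantizer, data, columnwise_data,
+ // scale, amax, columnwise_amax, scale_inv, columnwise_scale_inv, first_dims,
+ // last_dims, tensor_offsets); attributes set to None are simply skipped.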
+ const auto num_tensors = tensor.attr("num_tensors").cast(); + const auto logical_shape = tensor.attr("logical_shape").cast>(); + py::handle quantizer = py::none(); + DType quantizer_dtype = DType::kNumTypes; + NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING; + if (!tensor.attr("quantizer").is_none()) { + quantizer = tensor.attr("quantizer").cast(); + scaling_mode = ScalingModeFromQuantizer(quantizer); + quantizer_dtype = quantizer.attr("dtype").cast(); + } + auto ret = GroupedTensorWrapper(num_tensors, logical_shape, scaling_mode); + + // Rowwise data + if (!tensor.attr("data").is_none()) { + const auto &data = tensor.attr("data").cast(); + DType data_dtype = quantizer.is_none() ? GetTransformerEngineDType(data.scalar_type()) : quantizer_dtype; + ret.set_rowwise_data(data.data_ptr(), data_dtype, + getTensorShape(data)); + } + + // Columnwise data + if (!tensor.attr("columnwise_data").is_none()) { + const auto &data = tensor.attr("columnwise_data").cast(); + DType data_dtype = quantizer.is_none() ? GetTransformerEngineDType(data.scalar_type()) : quantizer_dtype; + ret.set_columnwise_data(data.data_ptr(), data_dtype, + getTensorShape(data)); + } + + // Scale + if (!tensor.attr("scale").is_none()) { + const auto &scale = tensor.attr("scale").cast(); + ret.set_scale(scale.data_ptr(), GetTransformerEngineDType(scale.scalar_type()), + getTensorShape(scale)); + } + + // Amax + if (!tensor.attr("amax").is_none()) { + const auto &amax = tensor.attr("amax").cast(); + ret.set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), + getTensorShape(amax)); + } + if (!tensor.attr("columnwise_amax").is_none()) { + const auto &amax = tensor.attr("columnwise_amax").cast(); + ret.set_columnwise_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), + getTensorShape(amax)); + } + + // Scale inverse + if (!tensor.attr("scale_inv").is_none()) { + const auto &scale_inv = tensor.attr("scale_inv").cast(); + ret.set_rowwise_scale_inv(scale_inv.data_ptr(), + GetTransformerEngineDTypeForScaleInv(quantizer, scale_inv), + getTensorShape(scale_inv)); + } + if (!tensor.attr("columnwise_scale_inv").is_none()) { + const auto &scale_inv = tensor.attr("columnwise_scale_inv").cast(); + ret.set_columnwise_scale_inv(scale_inv.data_ptr(), + GetTransformerEngineDTypeForScaleInv(quantizer, scale_inv), + getTensorShape(scale_inv)); + } + + // Shape metadata + if (!tensor.attr("first_dims").is_none()) { + const auto &first_dims = tensor.attr("first_dims").cast(); + ret.set_first_dims(first_dims.data_ptr(), GetTransformerEngineDType(first_dims.scalar_type()), + getTensorShape(first_dims)); + } + if (!tensor.attr("last_dims").is_none()) { + const auto &last_dims = tensor.attr("last_dims").cast(); + ret.set_last_dims(last_dims.data_ptr(), GetTransformerEngineDType(last_dims.scalar_type()), + getTensorShape(last_dims)); +} + if (!tensor.attr("tensor_offsets").is_none()) { + const auto &tensor_offsets = tensor.attr("tensor_offsets").cast(); + ret.set_tensor_offsets(tensor_offsets.data_ptr(), + GetTransformerEngineDType(tensor_offsets.scalar_type()), + getTensorShape(tensor_offsets)); + } + + return ret; +} + } // namespace detail -} // namespace transformer_engine::pytorch +} // namespace transformer_engine::pytorch \ No newline at end of file diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 1709bf1b37..2dc9794d0f 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ 
b/transformer_engine/pytorch/module/grouped_linear.py @@ -3,8 +3,12 @@ # See LICENSE for license information. """GroupedLinear API""" -from typing import Union, Optional, Callable, Tuple, List +from transformer_engine.pytorch.quantized_tensor import QuantizedTensorStorage +from torch._tensor import Tensor +from typing import Any, Union, Optional, Callable, Tuple, List from itertools import chain +from torch.distributed.tensor import DTensor + import warnings import functools @@ -39,6 +43,7 @@ ) from ..cpp_extensions import ( general_grouped_gemm, + general_grouped_gemm_for_grouped_tensor, ) from ..constants import GemmParallelModes, dist_group_type from ..jit import no_torch_dynamo @@ -57,6 +62,50 @@ __all__ = ["GroupedLinear"] +def _clone_grouped_tensor_with_data( + grouped_tensor: GroupedTensor, data: torch.Tensor, dtype: torch.dtype +) -> GroupedTensor: + return GroupedTensor( + num_tensors=grouped_tensor.num_tensors, + shape=grouped_tensor.shape, + quantizer=grouped_tensor.quantizer, + dtype=dtype, + data=data, + columnwise_data=grouped_tensor.columnwise_data, + scale_inv=grouped_tensor.scale_inv, + columnwise_scale_inv=grouped_tensor.columnwise_scale_inv, + amax=grouped_tensor.amax, + columnwise_amax=grouped_tensor.columnwise_amax, + scale=grouped_tensor.scale, + first_dims=grouped_tensor.first_dims, + last_dims=grouped_tensor.last_dims, + tensor_offsets=grouped_tensor.tensor_offsets, + offsets=grouped_tensor.offsets, + scale_inv_offsets=grouped_tensor.scale_inv_offsets, + columnwise_scale_inv_offsets=grouped_tensor.columnwise_scale_inv_offsets, + logical_shape=grouped_tensor.logical_shape, + ) + + +def _make_grouped_tensor_for_m_splits( + data: torch.Tensor, m_splits: torch.Tensor +) -> GroupedTensor: + # Use data.shape[0] to avoid first_dims.sum().item() D2H copy (breaks CUDA graph) + logical_first_dim = data.shape[0] + grouped = GroupedTensor.make_grouped_tensor( + num_tensors=int(m_splits.numel()), + first_dims=m_splits, + last_dims=None, + logical_first_dim=logical_first_dim, + logical_last_dim=data.shape[-1], + quantizer=None, + device=data.device, + dtype=data.dtype, + ) + grouped.data = data.contiguous().view(-1) + return grouped + + class _GroupedLinear(torch.autograd.Function): """GroupedLinear semi-top level module Calls custom cuda extensions. @@ -76,6 +125,7 @@ def forward( # to reduce CPU overhead due to pytorch arg checking. 
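+ # m_splits may be either a Python list of ints or a CUDA int64 tensor; the
+ # tensor form is what selects the GroupedTensor fast path below (see
+ # grouped_tensor_path), avoiding a device-to-host copy of the split sizes.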
( m_splits, + m_splits_is_tensor, use_bias, is_first_microbatch, fp8, @@ -97,10 +147,11 @@ def forward( save_original_input, debug, ) = non_tensor_args - - num_gemms = len(m_splits) - weights = weights_and_biases[:num_gemms] - biases = weights_and_biases[num_gemms:] + num_weight_params = module.num_weight_params + num_gemms = int(m_splits.numel()) if m_splits_is_tensor else len(m_splits) + logical_first_dim = inp.shape[0] if m_splits_is_tensor else sum(m_splits) + weights = weights_and_biases[:num_weight_params] + biases = weights_and_biases[num_weight_params:] device = inp.device weight_requires_grad = weights[0].requires_grad @@ -133,9 +184,11 @@ def forward( if output_quantizers[0] is not None: for output_quantizer in output_quantizers: output_quantizer.set_usage(rowwise=True, columnwise=False) + no_quantization = (not fp8 and weight_quantizers[0] is None) # Initialize input tensors - in_features = weights[0].size(-1) + in_features = module.in_features + out_features = module.out_features if inp.size(-1) != in_features: raise ValueError( f"Input tensor (shape={tuple(inp.size())}) is not compatible with " @@ -143,6 +196,12 @@ def forward( ) inp_view = inp.reshape(-1, in_features) inputmats: list + inp_view_cast = None + if m_splits_is_tensor and not no_quantization: + # TODO: Support this path. + raise ValueError("GroupedGEMM with grouped tensor path with quantization is not supported yet.") + grouped_tensor_path = no_quantization and m_splits_is_tensor + if fp8 and not debug: # Disable bulk allocation when CPU offloading is active: offloading skips small # tensors (like scales), but bulk allocation shares storage across all tensors, @@ -158,7 +217,8 @@ def forward( inp_view, input_quantizers, m_splits, activation_dtype ) else: - inputmats = torch.split(cast_if_needed(inp_view, activation_dtype), m_splits) + inp_view_cast = cast_if_needed(inp_view, activation_dtype) + inputmats = [inp_view_cast] if grouped_tensor_path else torch.split(inp_view_cast, m_splits) if cpu_offloading: start_offload(*inputmats) @@ -169,7 +229,7 @@ def forward( # FP8 cast to workspace buffer weights_fp8 = [] update_workspace = is_first_microbatch is None or is_first_microbatch - for i in range(num_gemms): + for i in range(num_weight_params): weight_fp8 = module.get_weight_workspace( tensor=weights[i], quantizer=weight_quantizers[i], @@ -190,7 +250,7 @@ def forward( biases = [cast_if_needed(bias, bias_dtype) for bias in biases] if use_bias else biases # Initialize output tensor out = torch.empty( - [sum(m_splits), weights_fp8[0].size(0)], + [logical_first_dim, out_features], dtype=activation_dtype, device=device, ) @@ -202,19 +262,35 @@ def forward( if hasattr(recipe, "fp8_gemm_fprop"): use_split_accumulator = recipe.fp8_gemm_fprop.use_split_accumulator - # Perform GEMM - general_grouped_gemm( - weights_fp8, - inputmats, - [out], - output_quantizers, - activation_dtype, - single_output=True, - m_splits=m_splits, - bias=biases, - use_bias=use_bias, - use_split_accumulator=use_split_accumulator, - ) + if grouped_tensor_path: + grouped_weight = _clone_grouped_tensor_with_data( + module.grouped_weight_storage, + cast_if_needed(module.grouped_weight_storage.data, activation_dtype), + activation_dtype, + ) + grouped_input = _make_grouped_tensor_for_m_splits(inputmats[0], m_splits) + grouped_out = _make_grouped_tensor_for_m_splits(out, m_splits) + general_grouped_gemm_for_grouped_tensor( + grouped_weight, + grouped_input, + grouped_out, + layout="TN", + accumulate=False, + ) + else: + # Perform GEMM + 
general_grouped_gemm( + weights_fp8, + inputmats, + [out], + output_quantizers, + activation_dtype, + single_output=True, + m_splits=m_splits, + bias=biases, + use_bias=use_bias, + use_split_accumulator=use_split_accumulator, + ) if fp8_calibration: for i in range(num_gemms): @@ -229,7 +305,10 @@ def forward( if is_grad_enabled: ctx.weight_quantizers = weight_quantizers - ctx.weights_shape_1 = weights[0].shape[1] + if module.single_weight: + ctx.weights_shape_1 = module.in_features + else: + ctx.weights_shape_1 = weights[0].shape[1] # TODO: update after #1638 is merged. # pylint: disable=fixme if weight_requires_grad: @@ -264,7 +343,6 @@ def forward( ) ctx.save_for_backward(*tensors_to_save) ctx.tensor_objects = tensor_objects - ctx.grad_input_quantizers = grad_input_quantizers ctx.grad_output_quantizers = grad_output_quantizers ctx.grad_weight_quantizers = grad_weight_quantizers @@ -276,17 +354,20 @@ def forward( # the main_grad buffer lazily before backprop if hasattr(weights[0], "__fsdp_param__"): # MCore FSDP creates main_grad lazily before backward - ctx.main_grad_funcs = [weights[i].get_main_grad for i in range(num_gemms)] + ctx.main_grad_funcs = [weights[i].get_main_grad for i in range(num_weight_params)] else: ctx.main_grad_funcs = [ - lambda j=i: weights[j].main_grad for i in range(num_gemms) + lambda j=i: weights[j].main_grad for i in range(num_weight_params) ] else: - ctx.main_grad_funcs = [lambda: None for i in range(num_gemms)] + ctx.main_grad_funcs = [lambda: None for i in range(num_weight_params)] ctx.device = device ctx.output_quantizers = output_quantizers ctx.m_splits = m_splits + ctx.logical_first_dim = logical_first_dim + ctx.grouped_tensor_path = grouped_tensor_path ctx.num_gemms = num_gemms + ctx.num_weight_params = num_weight_params ctx.activation_dtype = activation_dtype ctx.fp8 = fp8 ctx.fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() if fp8 else None @@ -295,6 +376,8 @@ def forward( ctx.is_first_microbatch = is_first_microbatch ctx.use_bias = use_bias ctx.sequence_parallel = sequence_parallel + ctx.in_features = module.in_features + ctx.out_features = module.out_features ctx.inp_shape = inp.shape ctx.requires_dgrad = inp.requires_grad ctx.reduce_and_update_bwd_fp8_tensors = False @@ -307,7 +390,10 @@ def forward( ctx.debug = debug ctx.save_original_input = save_original_input ctx.input_quantizers = input_quantizers - + ctx.single_weight = module.single_weight + ctx.grouped_weight_storage = ( + module.grouped_weight_storage if grouped_tensor_path else None + ) # [*, in_features] -> [*, out_features] except first dimension changes for SP return out.view(-1, *inp.shape[1:-1], out.shape[-1]) @@ -315,8 +401,9 @@ def forward( def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: # pylint: disable=missing-function-docstring with get_nvtx_range_context("_GroupedLinear_backward"): + m_splits = ctx.m_splits saved_tensors = restore_from_saved(ctx.tensor_objects, ctx.saved_tensors) - N = ctx.num_gemms + N = ctx.num_weight_params inputmats = saved_tensors[:N] weights = saved_tensors[N : 2 * N] origin_weights = saved_tensors[2 * N : 3 * N] @@ -366,7 +453,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ) elif ctx.debug: grad_output_mats = torch.split(grad_output_view, ctx.m_splits) - for i in range(ctx.num_gemms): + for i in range(ctx.num_weight_params): grad_biases[i] = grad_output_mats[i].sum(dim=0) grad_output = DebugQuantizer.multi_tensor_quantize( grad_output_view, @@ -377,10 +464,17 @@ def 
backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], else: # Only split grad output. Grad bias is fused with # wgrad GEMM. - grad_output = torch.split( - cast_if_needed(grad_output_view, ctx.activation_dtype), - ctx.m_splits, - ) + if ctx.grouped_tensor_path: + out = cast_if_needed(grad_output_view, ctx.activation_dtype) + grad_output = [out] + grouped_grad_output = _make_grouped_tensor_for_m_splits( + out, m_splits + ) + else: + grad_output = torch.split( + cast_if_needed(grad_output_view, ctx.activation_dtype), + ctx.m_splits, + ) if ctx.is_first_microbatch is not None: accumulate_wgrad_into_param_main_grad = ( @@ -398,27 +492,39 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], recipe.fp8_gemm_dgrad.use_split_accumulator ) dgrad = torch.empty( - (sum(ctx.m_splits), ctx.weights_shape_1), + ctx.inp_shape, dtype=ctx.activation_dtype, device=ctx.device, ) + # Make sure weights are available in column-wise format # for dgrad computation. for weight in weights: if isinstance(weight, QuantizedTensorStorage): weight.update_usage(columnwise_usage=True) - general_grouped_gemm( - weights, - grad_output, - [dgrad], - ctx.grad_input_quantizers, - ctx.activation_dtype, - single_output=True, - layout="NN", - m_splits=ctx.m_splits, - grad=True, - use_split_accumulator=dgrad_gemm_use_split_accumulator, - ) + if ctx.grouped_tensor_path: + grouped_weight = ctx.grouped_weight_storage + grouped_dgrad = _make_grouped_tensor_for_m_splits(dgrad, m_splits) + general_grouped_gemm_for_grouped_tensor( + grouped_weight, + grouped_grad_output, + grouped_dgrad, + layout="NN", + accumulate=False, + ) + else: + general_grouped_gemm( + weights, + grad_output, + [dgrad], + ctx.grad_input_quantizers, + ctx.activation_dtype, + single_output=True, + layout="NN", + m_splits=ctx.m_splits, + grad=True, + use_split_accumulator=dgrad_gemm_use_split_accumulator, + ) if ctx.weights_requires_grad: wgrad_gemm_use_split_accumulator = _2X_ACC_WGRAD @@ -428,7 +534,24 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], wgrad_gemm_use_split_accumulator = ( recipe.fp8_gemm_wgrad.use_split_accumulator ) - if ctx.fuse_wgrad_accumulation: + grouped_wgrad = None + if ctx.grouped_tensor_path and ctx.fuse_wgrad_accumulation: + raise NotImplementedError("Fused wgrad accumulation is not supported with grouped tensor path.") + if ctx.grouped_tensor_path: + # Wgrad GEMM writes one output per group; use num_gemms (not num_weight_params). 
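+                        # With single_weight=True there is only one weight parameter
+                        # (num_weight_params == 1), but the wgrad GEMM still produces one
+                        # (out_features, in_features) block per group; the flat view taken
+                        # below matches the packed layout of weight0.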
+ num_wgrad_tensors = ctx.num_gemms + grouped_wgrad = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_wgrad_tensors, + shape=[(ctx.out_features, ctx.in_features)] * num_wgrad_tensors, + quantizer=None, + dtype=ctx.activation_dtype, + device=ctx.device, + ) + if ctx.single_weight: + wgrad_list = [grouped_wgrad.data.view(-1)] + else: + wgrad_list = grouped_wgrad.split_into_quantized_tensors() + elif ctx.fuse_wgrad_accumulation: wgrad_list = main_grads else: wgrad_list = [ @@ -460,32 +583,67 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ctx.activation_dtype, ) else: - inputmats = torch.split( - cast_if_needed(inp_view, ctx.activation_dtype), ctx.m_splits + if ctx.grouped_tensor_path: + inputmats = [cast_if_needed(inp_view, ctx.activation_dtype)] + else: + inputmats = torch.split( + cast_if_needed(inp_view, ctx.activation_dtype), ctx.m_splits + ) + + if ctx.grouped_tensor_path: + def grouped_gemm_wgrad_grouped_tensor(inputmat, grad_output, grouped_wgrad): + grouped_input = _make_grouped_tensor_for_m_splits( + inputmat, ctx.m_splits ) - grouped_gemm_wgrad = functools.partial( - general_grouped_gemm, - quantization_params=ctx.grad_weight_quantizers, - out_dtype=ctx.activation_dtype, - layout="NT", - grad=True, - m_splits=ctx.m_splits, - use_bias=ctx.use_bias if grad_biases[0] is None else None, - bias=biases, - use_split_accumulator=wgrad_gemm_use_split_accumulator, - accumulate=( - accumulate_wgrad_into_param_main_grad - if not getattr(weights[0], "overwrite_main_grad", False) - else False - ), - ) + grouped_grad_output = _make_grouped_tensor_for_m_splits( + grad_output, ctx.m_splits + ) + # dW = grad_output^T @ input -> (out_features, m) @ (m, in_features). + # Row-wise: A (m, n) -> cuBLAS (n, m); use A=grad_output, B=input. + # Layout NT: op(A)=(n, m), op(B)^T=(m, k) -> D = (n, k). 
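+                        # e.g. with m_i=8, in_features=k=4, out_features=n=2:
+                        # grad_output_i is (8, 2), inp_i is (8, 4), and
+                        # grad_output_i^T @ inp_i yields the expected (2, 4) wgrad block.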
+ general_grouped_gemm_for_grouped_tensor( + grouped_grad_output, + grouped_input, + grouped_wgrad, + layout="NT", + accumulate=( + accumulate_wgrad_into_param_main_grad + if not getattr(weights[0], "overwrite_main_grad", False) + else False + ), + ) + return None, [None] * ctx.num_weight_params, None + + grouped_gemm_wgrad = grouped_gemm_wgrad_grouped_tensor + else: + grouped_gemm_wgrad = functools.partial( + general_grouped_gemm, + quantization_params=ctx.grad_weight_quantizers, + out_dtype=ctx.activation_dtype, + layout="NT", + grad=True, + m_splits=ctx.m_splits, + use_bias=ctx.use_bias if grad_biases[0] is None else None, + bias=biases, + use_split_accumulator=wgrad_gemm_use_split_accumulator, + accumulate=( + accumulate_wgrad_into_param_main_grad + if not getattr(weights[0], "overwrite_main_grad", False) + else False + ), + ) # WGRAD if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute(): ctx.wgrad_store.put([inputmats, grad_output, wgrad_list], grouped_gemm_wgrad) + elif ctx.grouped_tensor_path: + # Pass 2D view so _make_grouped_tensor_for_m_splits gets correct logical_last_dim + grad_output_2d = grad_output[0].view(ctx.logical_first_dim, ctx.out_features) + # wgrad_list shares the same memory with grouped_wgrad + grouped_gemm_wgrad(inputmats[0], grad_output_2d, grouped_wgrad) else: _, grad_biases_, _ = grouped_gemm_wgrad(inputmats, grad_output, wgrad_list) - for i in range(ctx.num_gemms): + for i in range(ctx.num_weight_params): if grad_biases[i] is None: grad_biases[i] = grad_biases_[i] del grad_biases_ @@ -522,14 +680,14 @@ def handle_custom_ddp_from_mcore(weight, wgrad): for weight, wgrad in zip(origin_weights, wgrad_list) ] else: - wgrad_list = [None] * ctx.num_gemms + wgrad_list = [None] * (ctx.num_weight_params) if not ctx.use_bias or ( ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute() and not ctx.fp8 ): - grad_biases = [None] * ctx.num_gemms + grad_biases = [None] * (ctx.num_weight_params) if ctx.reduce_and_update_bwd_fp8_tensors: FP8GlobalStateManager.reduce_and_update_fp8_tensors(forward=False) @@ -624,6 +782,7 @@ def __init__( delay_wgrad_compute: bool = False, save_original_input: bool = False, name: Optional[str] = None, + single_weight: bool = False, ) -> None: super().__init__(name) @@ -632,6 +791,9 @@ def __init__( self.in_features = in_features self.out_features = out_features self.fuse_wgrad_accumulation = fuse_wgrad_accumulation + if single_weight: + bias = False + return_bias = False self.use_bias = bias self.return_bias = return_bias self.apply_bias = bias and not return_bias @@ -639,6 +801,7 @@ def __init__( self.ub_overlap_ag = ub_overlap_ag self.ub_name = ub_name self.save_original_input = save_original_input + self.single_weight = single_weight assert ( not ub_overlap_rs and not ub_overlap_ag ), "GroupedLinear doesn't support Userbuffer overlap." 
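A minimal usage sketch of the two behaviors this file now supports, m_splits passed as a device tensor and a single packed weight. It mirrors the new test rather than prescribing an API; sizes are illustrative and the top-level te.GroupedLinear alias is assumed:

import torch
import transformer_engine.pytorch as te

num_gemms, in_features, out_features = 4, 128, 256        # illustrative sizes
m_splits = torch.tensor([16, 32, 48, 32], device="cuda")  # stays on device, no .item() sync

layer = te.GroupedLinear(
    num_gemms,
    in_features,
    out_features,
    bias=False,                  # single_weight forces bias off, so keep it off here too
    params_dtype=torch.bfloat16,
    device="cuda",
    single_weight=True,
)
inp = torch.randn(int(m_splits.sum()), in_features,
                  device="cuda", dtype=torch.bfloat16, requires_grad=True)
out = layer(inp, m_splits)       # unquantized grouped-tensor path
out.sum().backward()             # weight0.grad is one flat buffer covering all groups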
@@ -687,14 +850,31 @@ def __init__( self.sequence_parallel = (self.tp_size > 1) and sequence_parallel - for i in range(self.num_gemms): - # Construct weight parameter + if self.single_weight and self.primary_weights_in_fp8: + raise ValueError("Single weight is only supported for High precision weights.") + + if self.single_weight: + shape_weight = [(self.out_features * self.num_gemms * self.in_features,)] + shape_bias = [(self.out_features * self.num_gemms,)] + param_names = ["weight0", "bias0"] + self.num_weight_params = 1 + num_tensors = 1 + else: + shape_weight = [ + (self.out_features, self.in_features) for _ in range(self.num_gemms) + ] + shape_bias = [self.out_features for _ in range(self.num_gemms)] + num_tensors = self.num_gemms + param_names = [f"weight{i}" for i in range(self.num_gemms)]\ + + [f"bias{i}" for i in range(self.num_gemms)] + self.num_weight_params = self.num_gemms + + for i in range(num_tensors): self.register_parameter( f"weight{i}", torch.nn.Parameter( torch.empty( - self.out_features, - self.in_features, + shape_weight[i], device=device, dtype=self.params_dtype, ), @@ -710,7 +890,7 @@ def __init__( f"bias{i}", torch.nn.Parameter( torch.empty( - self.out_features, + shape_bias[i], device=device, dtype=self.params_dtype, ), @@ -729,9 +909,8 @@ def __init__( if self.wgrad_store.delay_wgrad_compute(): for name, param in self.named_parameters(): - for i in range(self.num_gemms): - if name in (f"weight{i}", f"bias{i}"): - param.skip_backward_post_hook = True + if name in param_names: + param.skip_backward_post_hook = True def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" @@ -749,6 +928,19 @@ def make_grouped_weights(self, defer_init=False) -> None: if defer_init: return + + if self.single_weight: + weight = getattr(self, "weight0") + logical_shape = (self.num_gemms * self.out_features, self.in_features) + self.grouped_weight_storage = GroupedTensor(num_tensors=self.num_gemms, + shape=[(self.out_features, self.in_features) for _ in range(self.num_gemms)], + quantizer=None, + dtype=self.params_dtype, + data=weight, + logical_shape=logical_shape, + ) + self.set_tensor_parallel_attributes(defer_init=defer_init) + return weights = [getattr(self, f"weight{i}") for i in range(self.num_gemms)] weight_quantizers = self._get_weight_quantizers() @@ -760,6 +952,7 @@ def make_grouped_weights(self, defer_init=False) -> None: quantizer=weight_quantizers[0], dtype=self.params_dtype, ) + self.grouped_weight_storage = grouped_weights # Copy existing params into storage. # TODO(ksivamani): Verify correctness of copy for all recipes. 
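The logical_shape recorded above can be sanity-checked with plain torch ops. A standalone sketch (illustrative sizes) of how the flat weight0 buffer maps onto per-GEMM (out_features, in_features) views that alias the same storage:

import torch

num_gemms, out_features, in_features = 3, 4, 5            # illustrative sizes
weight0 = torch.randn(num_gemms * out_features * in_features)

# Same logical 2D shape that make_grouped_weights records for the grouped storage.
logical = weight0.view(num_gemms * out_features, in_features)
blocks = logical.split(out_features, dim=0)               # one (out_features, in_features) view per GEMM

assert len(blocks) == num_gemms
assert all(b.shape == (out_features, in_features) for b in blocks)
# The views alias weight0's storage, so copying into them fills the packed parameter.
assert blocks[0].untyped_storage().data_ptr() == weight0.untyped_storage().data_ptr()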
@@ -788,7 +981,7 @@ def set_tensor_parallel_attributes(self, defer_init=False) -> None: if not defer_init: # Set parallelism attributes for linear weights - for i in range(self.num_gemms): + for i in range(self.num_weight_params): set_tensor_model_parallel_attributes( tensor=getattr(self, f"weight{i}"), is_parallel=True, @@ -798,13 +991,9 @@ def set_tensor_parallel_attributes(self, defer_init=False) -> None: # Set parallelism attributes for linear biases if self.use_bias: - for i in range(self.num_gemms): + for i in range(self.num_weight_params): if self.parallel_mode == "row": - setattr( - getattr(self, f"bias{i}"), - "sequence_parallel", - self.sequence_parallel, - ) + setattr(getattr(self, f"bias{i}"), "sequence_parallel", self.sequence_parallel) elif self.parallel_mode == "column": set_tensor_model_parallel_attributes(getattr(self, f"bias{i}"), True, 0, 1) @@ -812,7 +1001,7 @@ def set_tensor_parallel_attributes(self, defer_init=False) -> None: def forward( self, inp: torch.Tensor, - m_splits: List[int], + m_splits: Union[List[int], torch.Tensor], is_first_microbatch: Optional[bool] = None, ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]: """ @@ -822,8 +1011,8 @@ def forward( ---------- inp : torch.Tensor Input tensor. - m_splits : List[int] - List of integers representing the split of the input tensor. + m_splits : List[int] | torch.Tensor + List of integers or a device tensor representing the split of the input tensor. is_first_microbatch : {True, False, None}, default = None During training using either gradient accumulation or pipeline parallelism a minibatch of data is further split @@ -839,18 +1028,18 @@ def forward( produced) """ debug = self.is_debug_iter() - assert not isinstance( inp, QuantizedTensorStorage ), "GroupedLinear doesn't support input tensor in FP8." - assert len(m_splits) == self.num_gemms, "Number of splits should match number of GEMMs." - + m_splits_is_tensor = torch.is_tensor(m_splits) + num_splits = m_splits.numel() if m_splits_is_tensor else len(m_splits) + assert num_splits == self.num_gemms, "Number of splits should match number of GEMMs." 
is_grad_enabled = torch.is_grad_enabled() inp = self.prepare_forward(inp, num_gemms=self.num_gemms) try: weight_tensors = self._get_weight_tensors() - bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] + bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_weight_params)] quantizers = self._get_quantizers() if not debug else self._get_debug_quantizers() @@ -880,6 +1069,7 @@ def forward( non_tensor_args = ( m_splits, + m_splits_is_tensor, self.apply_bias, is_first_microbatch, self.fp8, @@ -923,10 +1113,10 @@ def backward_dw(self): weight_params = [getattr(self, f"weight{i}") for i in range(self.num_gemms)] bias_params = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] if not self.fuse_wgrad_accumulation: - for i in range(self.num_gemms): + for i in range(self.num_weight_params): weight_params[i].grad = wgrad_list[i].to(weight_params[i].dtype) if self.use_bias: - for i in range(self.num_gemms): + for i in range(self.num_weight_params): if bias_params[i].grad is None: bias_params[i].grad = grad_biases_[i].to(bias_params[i].dtype) del grad_biases_ @@ -971,7 +1161,7 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage]]: """Get the weight tensors of the module.""" - weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_gemms)] + weight_tensors = [getattr(self, f"weight{i}") for i in range(self.num_weight_params)] if not self.fp8 and any(isinstance(w, QuantizedTensorStorage) for w in weight_tensors): warnings.warn( "You are using quantized weights without quantized compute. " @@ -991,9 +1181,9 @@ def _get_weight_quantizers(self) -> List[Quantizer]: self.quantizers["scaling_fwd"][ self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"] ] - for i in range(self.num_gemms) + for i in range(self.num_weight_params) ] - for i in range(self.num_gemms): + for i in range(self.num_weight_params): weight_quantizers[i].internal = True return weight_quantizers From aa86859367ed87bdc243462b863e8b8d3eb6d02b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 10 Feb 2026 08:46:09 +0000 Subject: [PATCH 3/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/test_numerics.py | 3 +- .../transformer_engine/transformer_engine.h | 365 +++++++++--------- .../pytorch/cpp_extensions/gemm.py | 4 +- transformer_engine/pytorch/csrc/extensions.h | 10 +- .../pytorch/csrc/extensions/gemm.cpp | 10 +- .../pytorch/csrc/type_converters.cpp | 30 +- .../pytorch/module/grouped_linear.py | 51 +-- 7 files changed, 241 insertions(+), 232 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 401dc37a94..257396658d 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -2886,7 +2886,6 @@ def _pack_grouped_tensor(grouped_tensor: GroupedTensor, tensors: List[torch.Tens offset += numel - @pytest.mark.parametrize("layout", ["TN", "NN", "NT"]) @pytest.mark.parametrize("accumulate", [False]) def test_grouped_gemm_grouped_tensor(layout, accumulate): @@ -2941,6 +2940,7 @@ def test_grouped_gemm_grouped_tensor(layout, accumulate): ) device = A[0].device + def _make_grouped_tensor_from_splits(m_sizes, last_dim): first_dims = torch.tensor(m_sizes, device=device, dtype=torch.int64) return GroupedTensor.make_grouped_tensor( @@ -2995,6 +2995,7 @@ def 
_make_grouped_tensor_uniform(num_tensors, first_dim, last_dim): for o, o_ref in zip(out_grouped, out_ref): torch.testing.assert_close(o, o_ref, **tols) + @pytest.mark.parametrize("N", [32]) @pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16]) @pytest.mark.parametrize( diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index 1d4186489b..1e56ecaa9e 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -957,14 +957,13 @@ class TensorWrapper { NVTETensor tensor_ = nullptr; }; - /*! \struct GroupedTensorWrapper * \brief C++ wrapper for the NVTEGroupedTensor class. */ - class GroupedTensorWrapper { - public: - /*! \brief Constructs new GroupedTensorWrapper. +class GroupedTensorWrapper { + public: + /*! \brief Constructs new GroupedTensorWrapper. * * Create a new TE grouped tensor with a given logical shape. * TE grouped tensors are just wrappers on top of raw data and do not @@ -974,11 +973,11 @@ class TensorWrapper { * \param[in] logical_shape Logical 2D shape of the grouped data. * \param[in] scaling_mode Tensor data format. */ - GroupedTensorWrapper(const size_t num_tensors, const NVTEShape &logical_shape, - const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) - : tensor_(nvte_create_grouped_tensor(scaling_mode, num_tensors, logical_shape)) {} - - /*! \brief Constructs new GroupedTensorWrapper. + GroupedTensorWrapper(const size_t num_tensors, const NVTEShape &logical_shape, + const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) + : tensor_(nvte_create_grouped_tensor(scaling_mode, num_tensors, logical_shape)) {} + + /*! \brief Constructs new GroupedTensorWrapper. * * Create a new TE grouped tensor with a given logical shape. * @@ -986,184 +985,184 @@ class TensorWrapper { * \param[in] logical_shape Logical 2D shape of the grouped data. * \param[in] scaling_mode Tensor data format. */ - GroupedTensorWrapper(const size_t num_tensors, const std::vector &logical_shape, - const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) - : GroupedTensorWrapper(num_tensors, - nvte_make_shape(logical_shape.data(), logical_shape.size()), - scaling_mode) {} - - /*! \brief GroupedTensorWrapper destructor. */ - ~GroupedTensorWrapper() { nvte_destroy_grouped_tensor(tensor_); } - - GroupedTensorWrapper &operator=(const GroupedTensorWrapper &other) = delete; - GroupedTensorWrapper(const GroupedTensorWrapper &other) = delete; - - /*! \brief Constructs new GroupedTensorWrapper from existing GroupedTensorWrapper. */ - GroupedTensorWrapper(GroupedTensorWrapper &&other) { - tensor_ = other.tensor_; - other.tensor_ = nullptr; - } - - /*! \brief Assign the data from existing GroupedTensorWrapper. 
*/ - GroupedTensorWrapper &operator=(GroupedTensorWrapper &&other) { - if (this == &other) return *this; - nvte_destroy_grouped_tensor(tensor_); - tensor_ = other.tensor_; - other.tensor_ = nullptr; - return *this; - } - - // Parameter setters - template - GroupedTensorWrapper &set_parameter(const NVTEGroupedTensorParam param, void *dptr, DType type, - const ShapeType &shape) noexcept { - NVTEShape nvte_shape = this->convertShape(shape); - NVTEBasicTensor data = {dptr, static_cast(type), nvte_shape}; - nvte_set_grouped_tensor_param(&tensor_, param, &data); - return *this; - } - - template - GroupedTensorWrapper &set_rowwise_data(void *dptr, DType type, const ShapeType &shape) noexcept { - return set_parameter(kNVTEGroupedRowwiseData, dptr, type, shape); - } - - template - GroupedTensorWrapper &set_columnwise_data(void *dptr, DType type, - const ShapeType &shape) noexcept { - return set_parameter(kNVTEGroupedColumnwiseData, dptr, type, shape); - } - - template - GroupedTensorWrapper &set_scale(void *dptr, DType type, const ShapeType &shape) noexcept { - return set_parameter(kNVTEGroupedScale, dptr, type, shape); - } - - template - GroupedTensorWrapper &set_amax(void *dptr, DType type, const ShapeType &shape) noexcept { - return set_parameter(kNVTEGroupedAmax, dptr, type, shape); - } - - template - GroupedTensorWrapper &set_rowwise_scale_inv(void *dptr, DType type, - const ShapeType &shape) noexcept { - return set_parameter(kNVTEGroupedRowwiseScaleInv, dptr, type, shape); - } - - template - GroupedTensorWrapper &set_columnwise_scale_inv(void *dptr, DType type, - const ShapeType &shape) noexcept { - return set_parameter(kNVTEGroupedColumnwiseScaleInv, dptr, type, shape); - } - - template - GroupedTensorWrapper &set_columnwise_amax(void *dptr, DType type, - const ShapeType &shape) noexcept { - return set_parameter(kNVTEGroupedColumnwiseAmax, dptr, type, shape); - } - - template - GroupedTensorWrapper &set_first_dims(void *dptr, DType type, const ShapeType &shape) noexcept { - return set_parameter(kNVTEGroupedFirstDims, dptr, type, shape); - } - - template - GroupedTensorWrapper &set_last_dims(void *dptr, DType type, const ShapeType &shape) noexcept { - return set_parameter(kNVTEGroupedLastDims, dptr, type, shape); - } - - template - GroupedTensorWrapper &set_tensor_offsets(void *dptr, DType type, + GroupedTensorWrapper(const size_t num_tensors, const std::vector &logical_shape, + const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) + : GroupedTensorWrapper(num_tensors, + nvte_make_shape(logical_shape.data(), logical_shape.size()), + scaling_mode) {} + + /*! \brief GroupedTensorWrapper destructor. */ + ~GroupedTensorWrapper() { nvte_destroy_grouped_tensor(tensor_); } + + GroupedTensorWrapper &operator=(const GroupedTensorWrapper &other) = delete; + GroupedTensorWrapper(const GroupedTensorWrapper &other) = delete; + + /*! \brief Constructs new GroupedTensorWrapper from existing GroupedTensorWrapper. */ + GroupedTensorWrapper(GroupedTensorWrapper &&other) { + tensor_ = other.tensor_; + other.tensor_ = nullptr; + } + + /*! \brief Assign the data from existing GroupedTensorWrapper. 
*/ + GroupedTensorWrapper &operator=(GroupedTensorWrapper &&other) { + if (this == &other) return *this; + nvte_destroy_grouped_tensor(tensor_); + tensor_ = other.tensor_; + other.tensor_ = nullptr; + return *this; + } + + // Parameter setters + template + GroupedTensorWrapper &set_parameter(const NVTEGroupedTensorParam param, void *dptr, DType type, + const ShapeType &shape) noexcept { + NVTEShape nvte_shape = this->convertShape(shape); + NVTEBasicTensor data = {dptr, static_cast(type), nvte_shape}; + nvte_set_grouped_tensor_param(&tensor_, param, &data); + return *this; + } + + template + GroupedTensorWrapper &set_rowwise_data(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedRowwiseData, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_columnwise_data(void *dptr, DType type, const ShapeType &shape) noexcept { - return set_parameter(kNVTEGroupedTensorOffsets, dptr, type, shape); - } - - // Parameter getters - NVTEBasicTensor get_parameter(const NVTEGroupedTensorParam param) const noexcept { - return nvte_get_grouped_tensor_param(tensor_, param); - } - - NVTEBasicTensor get_rowwise_data() const noexcept { - return get_parameter(kNVTEGroupedRowwiseData); - } - - NVTEBasicTensor get_columnwise_data() const noexcept { - return get_parameter(kNVTEGroupedColumnwiseData); - } - - NVTEBasicTensor get_scale() const noexcept { return get_parameter(kNVTEGroupedScale); } - - NVTEBasicTensor get_amax() const noexcept { return get_parameter(kNVTEGroupedAmax); } - - NVTEBasicTensor get_rowwise_scale_inv() const noexcept { - return get_parameter(kNVTEGroupedRowwiseScaleInv); - } - - NVTEBasicTensor get_columnwise_scale_inv() const noexcept { - return get_parameter(kNVTEGroupedColumnwiseScaleInv); - } - - NVTEBasicTensor get_columnwise_amax() const noexcept { - return get_parameter(kNVTEGroupedColumnwiseAmax); - } - - NVTEBasicTensor get_first_dims() const noexcept { return get_parameter(kNVTEGroupedFirstDims); } - - NVTEBasicTensor get_last_dims() const noexcept { return get_parameter(kNVTEGroupedLastDims); } - - NVTEBasicTensor get_tensor_offsets() const noexcept { - return get_parameter(kNVTEGroupedTensorOffsets); - } - - /*! \brief Get an underlying NVTEGroupedTensor. 
+ return set_parameter(kNVTEGroupedColumnwiseData, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_scale(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedScale, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_amax(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedAmax, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_rowwise_scale_inv(void *dptr, DType type, + const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedRowwiseScaleInv, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_columnwise_scale_inv(void *dptr, DType type, + const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedColumnwiseScaleInv, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_columnwise_amax(void *dptr, DType type, + const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedColumnwiseAmax, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_first_dims(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedFirstDims, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_last_dims(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedLastDims, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_tensor_offsets(void *dptr, DType type, + const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedTensorOffsets, dptr, type, shape); + } + + // Parameter getters + NVTEBasicTensor get_parameter(const NVTEGroupedTensorParam param) const noexcept { + return nvte_get_grouped_tensor_param(tensor_, param); + } + + NVTEBasicTensor get_rowwise_data() const noexcept { + return get_parameter(kNVTEGroupedRowwiseData); + } + + NVTEBasicTensor get_columnwise_data() const noexcept { + return get_parameter(kNVTEGroupedColumnwiseData); + } + + NVTEBasicTensor get_scale() const noexcept { return get_parameter(kNVTEGroupedScale); } + + NVTEBasicTensor get_amax() const noexcept { return get_parameter(kNVTEGroupedAmax); } + + NVTEBasicTensor get_rowwise_scale_inv() const noexcept { + return get_parameter(kNVTEGroupedRowwiseScaleInv); + } + + NVTEBasicTensor get_columnwise_scale_inv() const noexcept { + return get_parameter(kNVTEGroupedColumnwiseScaleInv); + } + + NVTEBasicTensor get_columnwise_amax() const noexcept { + return get_parameter(kNVTEGroupedColumnwiseAmax); + } + + NVTEBasicTensor get_first_dims() const noexcept { return get_parameter(kNVTEGroupedFirstDims); } + + NVTEBasicTensor get_last_dims() const noexcept { return get_parameter(kNVTEGroupedLastDims); } + + NVTEBasicTensor get_tensor_offsets() const noexcept { + return get_parameter(kNVTEGroupedTensorOffsets); + } + + /*! \brief Get an underlying NVTEGroupedTensor. * * \return NVTEGroupedTensor held by this GroupedTensorWrapper. */ - NVTEGroupedTensor data() const noexcept { return tensor_; } - - /*! \brief Get the number of tensors in this GroupedTensorWrapper. */ - size_t num_tensors() const noexcept { - if (tensor_ == nullptr) return 0; - return nvte_grouped_tensor_num_tensors(tensor_); - } - - /*! \brief Get the data type of this GroupedTensorWrapper. */ - DType dtype() const noexcept { - if (tensor_ == nullptr) return DType::kNumTypes; - return static_cast(nvte_grouped_tensor_type(tensor_)); - } - - /*! \brief Get a scaling mode of the grouped tensor. 
*/ - NVTEScalingMode scaling_mode() const noexcept { - if (tensor_ == nullptr) return NVTE_DELAYED_TENSOR_SCALING; - return nvte_grouped_tensor_scaling_mode(tensor_); - } - - /*! \brief Get the logical shape of this GroupedTensorWrapper. */ - const NVTEShape logical_shape() const noexcept { - if (tensor_ == nullptr) { - return emptyShape; - } - return nvte_get_grouped_tensor_logical_shape(tensor_); - } - - static constexpr size_t defaultData = 1; - static constexpr NVTEShape defaultShape = { - {defaultData, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1}; - static constexpr NVTEShape emptyShape = {{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1}; - - private: - NVTEShape convertShape(const NVTEShape &s) { return s; } - - NVTEShape convertShape(const std::vector &s) { - return nvte_make_shape(s.data(), s.size()); - } - - /*! \brief Wrapped NVTEGroupedTensor. */ - NVTEGroupedTensor tensor_ = nullptr; - }; - - /*! \enum Float8BlockScaleTensorFormat + NVTEGroupedTensor data() const noexcept { return tensor_; } + + /*! \brief Get the number of tensors in this GroupedTensorWrapper. */ + size_t num_tensors() const noexcept { + if (tensor_ == nullptr) return 0; + return nvte_grouped_tensor_num_tensors(tensor_); + } + + /*! \brief Get the data type of this GroupedTensorWrapper. */ + DType dtype() const noexcept { + if (tensor_ == nullptr) return DType::kNumTypes; + return static_cast(nvte_grouped_tensor_type(tensor_)); + } + + /*! \brief Get a scaling mode of the grouped tensor. */ + NVTEScalingMode scaling_mode() const noexcept { + if (tensor_ == nullptr) return NVTE_DELAYED_TENSOR_SCALING; + return nvte_grouped_tensor_scaling_mode(tensor_); + } + + /*! \brief Get the logical shape of this GroupedTensorWrapper. */ + const NVTEShape logical_shape() const noexcept { + if (tensor_ == nullptr) { + return emptyShape; + } + return nvte_get_grouped_tensor_logical_shape(tensor_); + } + + static constexpr size_t defaultData = 1; + static constexpr NVTEShape defaultShape = { + {defaultData, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1}; + static constexpr NVTEShape emptyShape = {{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1}; + + private: + NVTEShape convertShape(const NVTEShape &s) { return s; } + + NVTEShape convertShape(const std::vector &s) { + return nvte_make_shape(s.data(), s.size()); + } + + /*! \brief Wrapped NVTEGroupedTensor. */ + NVTEGroupedTensor tensor_ = nullptr; +}; + +/*! \enum Float8BlockScaleTensorFormat * \brief Data format for an FP8 block-scaled tensor */ /*! \warning Deprecated */ diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 631569ac91..5ef0ef741e 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -310,6 +310,7 @@ def general_grouped_gemm( return out, bias, gelu_input + def get_grouped_gemm_setup_workspace_size(num_tensors: int) -> int: """Return workspace size for grouped GEMM pointer setup. Must match GroupedGemmSetupWorkspace::required_setup_size in cublaslt_grouped_gemm.cu. 
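For reference, a rough Python sketch of that size computation, assuming the simplified 6-pointer/6-int layout adopted later in this series and the 256-byte rounding used by the C++ test helper; the helper name below is illustrative, not the module's API:

def _setup_workspace_bytes(num_tensors: int, alignment: int = 256) -> int:
    """6 pointer arrays (A, B, C, D, alpha, beta) + 6 int arrays of dims, rounded up."""
    ptr_bytes = num_tensors * 8   # sizeof(void*)
    int_bytes = num_tensors * 4   # sizeof(int)
    size = 6 * ptr_bytes + 6 * int_bytes
    return (size + alignment - 1) // alignment * alignment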
@@ -349,7 +350,8 @@ def general_grouped_gemm_for_grouped_tensor( num_tensors = A.num_tensors assert A.num_tensors == B.num_tensors == out.num_tensors, ( - f"GroupedTensor num_tensors must match: A={A.num_tensors}, B={B.num_tensors}, out={out.num_tensors}" + f"GroupedTensor num_tensors must match: A={A.num_tensors}, B={B.num_tensors}," + f" out={out.num_tensors}" ) if out.data is not None: diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index fa7e27ed00..c365eb4bfc 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -150,10 +150,12 @@ std::optional> te_general_grouped_gemm( std::vector pre_gelu_out, bool grad, std::vector workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count); -py::object te_general_grouped_gemm_for_grouped_tensor( - py::handle A, bool transa, py::handle B, bool transb, py::object C, py::handle D, - at::Tensor alpha, at::Tensor beta, at::Tensor workspace_setup, at::Tensor workspace_cublas, - int math_sm_count); +py::object te_general_grouped_gemm_for_grouped_tensor(py::handle A, bool transa, py::handle B, + bool transb, py::object C, py::handle D, + at::Tensor alpha, at::Tensor beta, + at::Tensor workspace_setup, + at::Tensor workspace_cublas, + int math_sm_count); /*************************************************************************************************** * Transpose diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 765f346bb2..d5a8ff5489 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -570,10 +570,12 @@ std::optional> te_general_grouped_gemm( return bias; } -py::object te_general_grouped_gemm_for_grouped_tensor( - py::handle A, bool transa, py::handle B, bool transb, py::object C, py::handle D, - at::Tensor alpha, at::Tensor beta, at::Tensor workspace_setup, at::Tensor workspace_cublas, - int math_sm_count) { +py::object te_general_grouped_gemm_for_grouped_tensor(py::handle A, bool transa, py::handle B, + bool transb, py::object C, py::handle D, + at::Tensor alpha, at::Tensor beta, + at::Tensor workspace_setup, + at::Tensor workspace_cublas, + int math_sm_count) { using namespace transformer_engine::pytorch::detail; init_extension(); diff --git a/transformer_engine/pytorch/csrc/type_converters.cpp b/transformer_engine/pytorch/csrc/type_converters.cpp index e6e392b894..07961d85d4 100644 --- a/transformer_engine/pytorch/csrc/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/type_converters.cpp @@ -199,8 +199,6 @@ DType GetTransformerEngineDTypeForScaleInv(py::handle quantizer, at::Tensor scal return GetTransformerEngineDType(scale_inv.scalar_type()); } - - GroupedTensorWrapper GroupedTensorFromPyTorchGroupedTensor(py::handle tensor) { // Returns a GroupedTensorWrapper from a PyTorch GroupedTensor. const auto num_tensors = tensor.attr("num_tensors").cast(); @@ -218,17 +216,17 @@ GroupedTensorWrapper GroupedTensorFromPyTorchGroupedTensor(py::handle tensor) { // Rowwise data if (!tensor.attr("data").is_none()) { const auto &data = tensor.attr("data").cast(); - DType data_dtype = quantizer.is_none() ? GetTransformerEngineDType(data.scalar_type()) : quantizer_dtype; - ret.set_rowwise_data(data.data_ptr(), data_dtype, - getTensorShape(data)); + DType data_dtype = + quantizer.is_none() ? 
GetTransformerEngineDType(data.scalar_type()) : quantizer_dtype; + ret.set_rowwise_data(data.data_ptr(), data_dtype, getTensorShape(data)); } // Columnwise data if (!tensor.attr("columnwise_data").is_none()) { const auto &data = tensor.attr("columnwise_data").cast(); - DType data_dtype = quantizer.is_none() ? GetTransformerEngineDType(data.scalar_type()) : quantizer_dtype; - ret.set_columnwise_data(data.data_ptr(), data_dtype, - getTensorShape(data)); + DType data_dtype = + quantizer.is_none() ? GetTransformerEngineDType(data.scalar_type()) : quantizer_dtype; + ret.set_columnwise_data(data.data_ptr(), data_dtype, getTensorShape(data)); } // Scale @@ -242,7 +240,7 @@ GroupedTensorWrapper GroupedTensorFromPyTorchGroupedTensor(py::handle tensor) { if (!tensor.attr("amax").is_none()) { const auto &amax = tensor.attr("amax").cast(); ret.set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), - getTensorShape(amax)); + getTensorShape(amax)); } if (!tensor.attr("columnwise_amax").is_none()) { const auto &amax = tensor.attr("columnwise_amax").cast(); @@ -260,26 +258,26 @@ GroupedTensorWrapper GroupedTensorFromPyTorchGroupedTensor(py::handle tensor) { if (!tensor.attr("columnwise_scale_inv").is_none()) { const auto &scale_inv = tensor.attr("columnwise_scale_inv").cast(); ret.set_columnwise_scale_inv(scale_inv.data_ptr(), - GetTransformerEngineDTypeForScaleInv(quantizer, scale_inv), - getTensorShape(scale_inv)); + GetTransformerEngineDTypeForScaleInv(quantizer, scale_inv), + getTensorShape(scale_inv)); } // Shape metadata if (!tensor.attr("first_dims").is_none()) { const auto &first_dims = tensor.attr("first_dims").cast(); ret.set_first_dims(first_dims.data_ptr(), GetTransformerEngineDType(first_dims.scalar_type()), - getTensorShape(first_dims)); + getTensorShape(first_dims)); } if (!tensor.attr("last_dims").is_none()) { const auto &last_dims = tensor.attr("last_dims").cast(); ret.set_last_dims(last_dims.data_ptr(), GetTransformerEngineDType(last_dims.scalar_type()), getTensorShape(last_dims)); -} + } if (!tensor.attr("tensor_offsets").is_none()) { const auto &tensor_offsets = tensor.attr("tensor_offsets").cast(); ret.set_tensor_offsets(tensor_offsets.data_ptr(), - GetTransformerEngineDType(tensor_offsets.scalar_type()), - getTensorShape(tensor_offsets)); + GetTransformerEngineDType(tensor_offsets.scalar_type()), + getTensorShape(tensor_offsets)); } return ret; @@ -287,4 +285,4 @@ GroupedTensorWrapper GroupedTensorFromPyTorchGroupedTensor(py::handle tensor) { } // namespace detail -} // namespace transformer_engine::pytorch \ No newline at end of file +} // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 2dc9794d0f..be3784d584 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -87,9 +87,7 @@ def _clone_grouped_tensor_with_data( ) -def _make_grouped_tensor_for_m_splits( - data: torch.Tensor, m_splits: torch.Tensor -) -> GroupedTensor: +def _make_grouped_tensor_for_m_splits(data: torch.Tensor, m_splits: torch.Tensor) -> GroupedTensor: # Use data.shape[0] to avoid first_dims.sum().item() D2H copy (breaks CUDA graph) logical_first_dim = data.shape[0] grouped = GroupedTensor.make_grouped_tensor( @@ -184,7 +182,7 @@ def forward( if output_quantizers[0] is not None: for output_quantizer in output_quantizers: output_quantizer.set_usage(rowwise=True, columnwise=False) - no_quantization = (not fp8 and 
weight_quantizers[0] is None) + no_quantization = not fp8 and weight_quantizers[0] is None # Initialize input tensors in_features = module.in_features @@ -199,7 +197,9 @@ def forward( inp_view_cast = None if m_splits_is_tensor and not no_quantization: # TODO: Support this path. - raise ValueError("GroupedGEMM with grouped tensor path with quantization is not supported yet.") + raise ValueError( + "GroupedGEMM with grouped tensor path with quantization is not supported yet." + ) grouped_tensor_path = no_quantization and m_splits_is_tensor if fp8 and not debug: @@ -218,7 +218,9 @@ def forward( ) else: inp_view_cast = cast_if_needed(inp_view, activation_dtype) - inputmats = [inp_view_cast] if grouped_tensor_path else torch.split(inp_view_cast, m_splits) + inputmats = ( + [inp_view_cast] if grouped_tensor_path else torch.split(inp_view_cast, m_splits) + ) if cpu_offloading: start_offload(*inputmats) @@ -354,7 +356,9 @@ def forward( # the main_grad buffer lazily before backprop if hasattr(weights[0], "__fsdp_param__"): # MCore FSDP creates main_grad lazily before backward - ctx.main_grad_funcs = [weights[i].get_main_grad for i in range(num_weight_params)] + ctx.main_grad_funcs = [ + weights[i].get_main_grad for i in range(num_weight_params) + ] else: ctx.main_grad_funcs = [ lambda j=i: weights[j].main_grad for i in range(num_weight_params) @@ -467,9 +471,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if ctx.grouped_tensor_path: out = cast_if_needed(grad_output_view, ctx.activation_dtype) grad_output = [out] - grouped_grad_output = _make_grouped_tensor_for_m_splits( - out, m_splits - ) + grouped_grad_output = _make_grouped_tensor_for_m_splits(out, m_splits) else: grad_output = torch.split( cast_if_needed(grad_output_view, ctx.activation_dtype), @@ -536,7 +538,9 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ) grouped_wgrad = None if ctx.grouped_tensor_path and ctx.fuse_wgrad_accumulation: - raise NotImplementedError("Fused wgrad accumulation is not supported with grouped tensor path.") + raise NotImplementedError( + "Fused wgrad accumulation is not supported with grouped tensor path." + ) if ctx.grouped_tensor_path: # Wgrad GEMM writes one output per group; use num_gemms (not num_weight_params). 
num_wgrad_tensors = ctx.num_gemms @@ -548,7 +552,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], device=ctx.device, ) if ctx.single_weight: - wgrad_list = [grouped_wgrad.data.view(-1)] + wgrad_list = [grouped_wgrad.data.view(-1)] else: wgrad_list = grouped_wgrad.split_into_quantized_tensors() elif ctx.fuse_wgrad_accumulation: @@ -591,10 +595,9 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ) if ctx.grouped_tensor_path: + def grouped_gemm_wgrad_grouped_tensor(inputmat, grad_output, grouped_wgrad): - grouped_input = _make_grouped_tensor_for_m_splits( - inputmat, ctx.m_splits - ) + grouped_input = _make_grouped_tensor_for_m_splits(inputmat, ctx.m_splits) grouped_grad_output = _make_grouped_tensor_for_m_splits( grad_output, ctx.m_splits ) @@ -860,13 +863,12 @@ def __init__( self.num_weight_params = 1 num_tensors = 1 else: - shape_weight = [ - (self.out_features, self.in_features) for _ in range(self.num_gemms) - ] + shape_weight = [(self.out_features, self.in_features) for _ in range(self.num_gemms)] shape_bias = [self.out_features for _ in range(self.num_gemms)] num_tensors = self.num_gemms - param_names = [f"weight{i}" for i in range(self.num_gemms)]\ - + [f"bias{i}" for i in range(self.num_gemms)] + param_names = [f"weight{i}" for i in range(self.num_gemms)] + [ + f"bias{i}" for i in range(self.num_gemms) + ] self.num_weight_params = self.num_gemms for i in range(num_tensors): @@ -928,11 +930,12 @@ def make_grouped_weights(self, defer_init=False) -> None: if defer_init: return - + if self.single_weight: weight = getattr(self, "weight0") logical_shape = (self.num_gemms * self.out_features, self.in_features) - self.grouped_weight_storage = GroupedTensor(num_tensors=self.num_gemms, + self.grouped_weight_storage = GroupedTensor( + num_tensors=self.num_gemms, shape=[(self.out_features, self.in_features) for _ in range(self.num_gemms)], quantizer=None, dtype=self.params_dtype, @@ -993,7 +996,9 @@ def set_tensor_parallel_attributes(self, defer_init=False) -> None: if self.use_bias: for i in range(self.num_weight_params): if self.parallel_mode == "row": - setattr(getattr(self, f"bias{i}"), "sequence_parallel", self.sequence_parallel) + setattr( + getattr(self, f"bias{i}"), "sequence_parallel", self.sequence_parallel + ) elif self.parallel_mode == "column": set_tensor_model_parallel_attributes(getattr(self, f"bias{i}"), True, 0, 1) From 98cb4fad5819d634d7a9ace7977224916ac9e1dc Mon Sep 17 00:00:00 2001 From: Varun Thumbe Date: Wed, 11 Feb 2026 03:11:53 +0000 Subject: [PATCH 4/5] remove changes not needed for bf16 Signed-off-by: Varun Thumbe --- tests/cpp/operator/test_grouped_gemm.cu | 9 ++--- tests/pytorch/test_numerics.py | 39 ++++++++----------- .../common/gemm/cublaslt_grouped_gemm.cu | 30 ++------------ 3 files changed, 25 insertions(+), 53 deletions(-) diff --git a/tests/cpp/operator/test_grouped_gemm.cu b/tests/cpp/operator/test_grouped_gemm.cu index 65d0248b0a..2ea4ea0cfa 100644 --- a/tests/cpp/operator/test_grouped_gemm.cu +++ b/tests/cpp/operator/test_grouped_gemm.cu @@ -44,10 +44,10 @@ enum class ShapeCase { size_t grouped_setup_workspace_size(const size_t num_tensors) { const size_t ptr_bytes = num_tensors * sizeof(void*); const size_t int_bytes = num_tensors * sizeof(int); - // Layout: 8 pointer arrays (A, B, C, D, alpha, beta, a_scale, b_scale) + 6 int arrays - size_t size = 8 * ptr_bytes + 6 * int_bytes; + // Layout: 6 pointer arrays (A, B, C, D, alpha, beta) + 6 int arrays (a_rows, a_cols, b_rows, b_cols, 
d_rows, d_cols) + size_t size = 6 * ptr_bytes + 6 * int_bytes; const size_t alignment = 256; - size = ((size + alignment - 1) / alignment) * alignment; + size = ((size + alignment - 1) / alignment) * alignment; return size; } @@ -291,8 +291,7 @@ std::string MakeGroupedGemmTestName(const testing::TestParamInfo kTestParams = { - // FP8 tests (each tensor has random mean/stddev -> different scales) - {InputCase::kFP8Current, true, false, ShapeCase::kAllDifferent, false}, + // Basic tests {InputCase::kFP8Current, false, true, ShapeCase::kAllDifferent, false}, {InputCase::kFP8Current, false, false, ShapeCase::kAllSame, false}, {InputCase::kBF16, true, false, ShapeCase::kSameFirst, false}, diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 401dc37a94..2224363e30 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -2018,7 +2018,7 @@ def test_grouped_linear_m_splits_tensor(single_weight): dtype = torch.bfloat16 m_total = int(m_splits.sum().item()) - reference = GroupedLinear( + reference_model = GroupedLinear( num_gemms, in_features, out_features, @@ -2026,11 +2026,11 @@ def test_grouped_linear_m_splits_tensor(single_weight): params_dtype=dtype, device="cuda", single_weight=False, - ).eval() + ) with torch.no_grad(): - ref_weights = [getattr(reference, f"weight{i}") for i in range(num_gemms)] + ref_weights = [getattr(reference_model, f"weight{i}") for i in range(num_gemms)] - model_under_test = GroupedLinear( + test_model = GroupedLinear( num_gemms, in_features, out_features, @@ -2038,26 +2038,26 @@ def test_grouped_linear_m_splits_tensor(single_weight): params_dtype=dtype, device="cuda", single_weight=single_weight, - ).eval() + ) with torch.no_grad(): if single_weight: for i, w in enumerate( - model_under_test.grouped_weight_storage.split_into_quantized_tensors() + test_model.grouped_weight_storage.split_into_quantized_tensors() ): w.copy_(ref_weights[i]) else: for i in range(num_gemms): - getattr(model_under_test, f"weight{i}").copy_(ref_weights[i]) + getattr(test_model, f"weight{i}").copy_(ref_weights[i]) inp = torch.randn(m_total, in_features, device="cuda", dtype=dtype, requires_grad=True) inp_ref = inp.detach().clone().requires_grad_() if single_weight: - out = model_under_test(inp, m_splits) - out_ref = reference(inp_ref, m_splits) + out = test_model(inp, m_splits) + out_ref = reference_model(inp_ref, m_splits) else: - out = model_under_test(inp, m_splits) - out_ref = model_under_test(inp_ref, m_splits_list) + out = test_model(inp, m_splits) + out_ref = reference_model(inp_ref, m_splits_list) torch.testing.assert_close(out, out_ref, **dtype_tols(dtype)) @@ -2067,10 +2067,10 @@ def test_grouped_linear_m_splits_tensor(single_weight): torch.testing.assert_close(inp.grad, inp_ref.grad, **dtype_tols(dtype)) if single_weight: ref_wgrad = torch.cat( - [getattr(reference, f"weight{i}").grad.view(-1) for i in range(num_gemms)] + [getattr(reference_model, f"weight{i}").grad.view(-1) for i in range(num_gemms)] ) torch.testing.assert_close( - getattr(model_under_test, "weight0").grad, ref_wgrad, **dtype_tols(dtype) + getattr(test_model, "weight0").grad, ref_wgrad, **dtype_tols(dtype) ) @@ -2802,15 +2802,10 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass): torch.manual_seed(0) z, m, k, n = shape - if z == 1: - m_splits = [m] - else: - split_points = torch.randperm(m - 1)[: z - 1] + 1 - split_points = torch.sort(split_points).values.tolist() - m_splits = [split_points[0]] - m_splits += [b - a for a, b in 
zip(split_points[:-1], split_points[1:])] - m_splits.append(m - split_points[-1]) - assert sum(m_splits) == m and len(m_splits) == z + dist = torch.sort(torch.randint(0, m, (z - 1,))).values.tolist() + m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist) + assert m_splits.sum() == m and len(m_splits) == z + m_splits = m_splits.tolist() if layout == "TN": A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)] # weight diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index 5f33ab2733..f1333f3491 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -119,42 +119,27 @@ struct GroupedGemmSetupWorkspace { int *d_cols; // N (last dim) - also used for C // Initialize from workspace buffer - // Layout: all pointer arrays first (16-byte aligned for cuBLAS), then int arrays + // Layout: all pointer arrays first (8-byte aligned), then int arrays (4-byte aligned) static GroupedGemmSetupWorkspace from_buffers(char *setup_ws_ptr, size_t num_tensors) { GroupedGemmSetupWorkspace ws; size_t offset = 0; const size_t ptr_size = num_tensors * sizeof(void *); const size_t int_size = num_tensors * sizeof(int); - constexpr size_t kPtrAlignment = 16; // cuBLAS requires 16-byte alignment for pointer arrays - - // Helper to align offset to kPtrAlignment - auto align_offset = [&]() { - offset = (offset + kPtrAlignment - 1) / kPtrAlignment * kPtrAlignment; - }; - - // Pointer arrays first (all 16-byte aligned for cuBLAS grouped GEMM) - align_offset(); + // Pointer arrays first (all 8-byte aligned) ws.A_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; - align_offset(); ws.B_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; - align_offset(); ws.C_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; - align_offset(); ws.D_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; - align_offset(); ws.alpha_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; - align_offset(); ws.beta_ptrs = reinterpret_cast(setup_ws_ptr + offset); offset += ptr_size; - // Int arrays for storage dimensions (4-byte aligned) - align_offset(); ws.a_rows = reinterpret_cast(setup_ws_ptr + offset); offset += int_size; ws.a_cols = reinterpret_cast(setup_ws_ptr + offset); @@ -174,11 +159,8 @@ struct GroupedGemmSetupWorkspace { static size_t required_setup_size(size_t num_tensors, size_t alignment) { const size_t ptr_size = num_tensors * sizeof(void *); const size_t int_size = num_tensors * sizeof(int); - constexpr size_t kPtrAlignment = 16; // Must match from_buffers - // Layout: 8 ptr arrays (each 16-byte aligned), then 6 int arrays - // Each ptr array takes ptr_size bytes but needs to start at 16-byte boundary - auto aligned_ptr_size = ((ptr_size + kPtrAlignment - 1) / kPtrAlignment) * kPtrAlignment; - size_t size = 8 * aligned_ptr_size + 6 * int_size; + // Layout: 6 ptr arrays, then 6 int arrays + size_t size = 6 * ptr_size + 6 * int_size; size = ((size + alignment - 1) / alignment) * alignment; return size; } @@ -401,10 +383,6 @@ inline void init_matmul_desc(cublasLtMatmulDescOpaque_t &matmulDesc, cublasOpera NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_BETA_BATCH_STRIDE, &alphabeta_batch_stride, sizeof(int64_t))); - // Fast accumulation mode: 0 = split accumulator (more accurate), 1 = fast accumulator - int8_t fastAccuMode = 0; // Use split 
accumulator for accuracy - NVTE_CHECK_CUBLAS(cublasLtMatmulDescSetAttribute(&matmulDesc, CUBLASLT_MATMUL_DESC_FAST_ACCUM, - &fastAccuMode, sizeof(fastAccuMode))); } inline void set_fp8_scale_pointers(cublasLtMatmulDescOpaque_t &matmulDesc, From 1d041f89023d000c1340c66a26612897c59a2275 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 11 Feb 2026 03:14:59 +0000 Subject: [PATCH 5/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/cpp/operator/test_grouped_gemm.cu | 2 +- tests/pytorch/test_numerics.py | 4 +--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/cpp/operator/test_grouped_gemm.cu b/tests/cpp/operator/test_grouped_gemm.cu index 2ea4ea0cfa..73969ca297 100644 --- a/tests/cpp/operator/test_grouped_gemm.cu +++ b/tests/cpp/operator/test_grouped_gemm.cu @@ -47,7 +47,7 @@ size_t grouped_setup_workspace_size(const size_t num_tensors) { // Layout: 6 pointer arrays (A, B, C, D, alpha, beta) + 6 int arrays (a_rows, a_cols, b_rows, b_cols, d_rows, d_cols) size_t size = 6 * ptr_bytes + 6 * int_bytes; const size_t alignment = 256; - size = ((size + alignment - 1) / alignment) * alignment; + size = ((size + alignment - 1) / alignment) * alignment; return size; } diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 83831c81d0..366b2a62c5 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -2041,9 +2041,7 @@ def test_grouped_linear_m_splits_tensor(single_weight): ) with torch.no_grad(): if single_weight: - for i, w in enumerate( - test_model.grouped_weight_storage.split_into_quantized_tensors() - ): + for i, w in enumerate(test_model.grouped_weight_storage.split_into_quantized_tensors()): w.copy_(ref_weights[i]) else: for i in range(num_gemms):