diff --git a/tests/pytorch/test_grouped_tensor.py b/tests/pytorch/test_grouped_tensor.py new file mode 100644 index 0000000000..318009c669 --- /dev/null +++ b/tests/pytorch/test_grouped_tensor.py @@ -0,0 +1,385 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Tests for GroupedTensor class""" + +from typing import List, Tuple +import pytest +import torch +import transformer_engine.pytorch as te +from transformer_engine.pytorch.tensor.storage.grouped_tensor import GroupedTensor +from transformer_engine.pytorch import ( + Quantizer, + Float8Quantizer, + Float8CurrentScalingQuantizer, + Float8BlockQuantizer, + MXFP8Quantizer, + NVFP4Quantizer, +) +from transformer_engine.pytorch.constants import TE_DType_To_Torch +import transformer_engine_torch as tex + +# Check available recipes +fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True) +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = te.is_fp8_block_scaling_available( + return_reason=True +) +mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True) +nvfp4_available, reason_for_no_nvfp4 = te.is_nvfp4_available(return_reason=True) + +_quantization_params = [ + pytest.param( + "fp8_delayed_scaling", + marks=pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8), + ), + pytest.param( + "fp8_current_scaling", + marks=pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8), + ), + pytest.param( + "fp8_blockwise", + marks=pytest.mark.skipif( + not fp8_block_scaling_available, reason=reason_for_no_fp8_block_scaling + ), + ), + pytest.param( + "mxfp8", + marks=pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8), + ), + pytest.param( + "nvfp4", + marks=pytest.mark.skipif(not nvfp4_available, reason=reason_for_no_nvfp4), + ), +] + + +def make_quantizer(quantization: str, num_tensors: int, shape: List[Tuple[int, int]]) -> Quantizer: + """Create quantizers for given quantization scheme""" + + if quantization == "fp8_delayed_scaling": + quantizer = Float8Quantizer( + scale=torch.ones(1, dtype=torch.float32, device="cuda"), + amax=torch.zeros(1, dtype=torch.float32, device="cuda"), + fp8_dtype=tex.DType.kFloat8E4M3, + ) + elif quantization == "fp8_current_scaling": + quantizer = Float8CurrentScalingQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + device="cuda", + ) + quantizer.set_usage(rowwise=True, columnwise=False) + elif quantization == "fp8_blockwise": + quantizer = Float8BlockQuantizer( + fp8_dtype=tex.DType.kFloat8E4M3, + rowwise=True, + columnwise=False, + force_pow_2_scales=True, + amax_epsilon=0.0, + block_scaling_dim=1, + ) + elif quantization == "mxfp8": + quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) + elif quantization == "nvfp4": + quantizer = NVFP4Quantizer( + with_rht=False, + with_post_rht_amax=False, + with_2d_quantization=False, + stochastic_rounding=False, + with_random_sign_mask=False, + ) + else: + raise ValueError(f"Unknown quantization scheme: {quantization}") + + quantizer.internal = False + + return quantizer + + +def _get_rowwise_data_tensor(qtensor, quantization: str) -> torch.Tensor: + if quantization in ("fp8_delayed_scaling", "fp8_current_scaling"): + return qtensor._data + if quantization in ("fp8_blockwise", "mxfp8", "nvfp4"): + return qtensor._rowwise_data + raise ValueError(f"Unknown quantization scheme: {quantization}") + + +def _rowwise_offset_bytes(numel: int, quantization: str) -> int: + if quantization == "nvfp4": + return numel // 
2 + return numel + + +class TestGroupedTensor: + @staticmethod + def setup_class(cls) -> None: + # Configure RNG + seed = 1234 + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + + def test_basic_construction_all_same_shape(self) -> None: + """Test GroupedTensor construction with all tensors having same shape""" + num_tensors = 4 + shape = [(256, 512) for _ in range(num_tensors)] + + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shape=shape, + quantizer=None, + device="cuda", + dtype=torch.float32, + ) + + assert grouped_tensor.num_tensors == num_tensors + assert grouped_tensor.all_same_shape() + assert grouped_tensor.all_same_first_dim() + assert grouped_tensor.all_same_last_dim() + assert grouped_tensor.logical_shape == (num_tensors * 256, 512) + assert grouped_tensor.get_common_first_dim() == 256 + assert grouped_tensor.get_common_last_dim() == 512 + assert grouped_tensor.has_data() + + def test_basic_construction_varying_first_dim(self) -> None: + """Test GroupedTensor construction with varying first dimension""" + num_tensors = 3 + shape = [(128, 512), (256, 512), (384, 512)] + + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shape=shape, + quantizer=None, + device="cuda", + dtype=torch.float32, + ) + + assert grouped_tensor.num_tensors == num_tensors + assert not grouped_tensor.all_same_shape() + assert not grouped_tensor.all_same_first_dim() + assert grouped_tensor.all_same_last_dim() + assert grouped_tensor.get_common_last_dim() == shape[0][1] + assert grouped_tensor.logical_shape == ( + sum(v for v, _ in shape), + shape[0][1], + ) # sum of first dims + + def test_split_into_quantized_tensors_no_quantization(self) -> None: + """Test split_into_quantized_tensors for unquantized tensors""" + num_tensors = 3 + shape = [(256, 512) for _ in range(num_tensors)] + + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shape=shape, + quantizer=None, + device="cuda", + dtype=torch.float32, + ) + + # Get the original data pointer + original_data_ptr = grouped_tensor.data.data_ptr() + + # Split into tensors + tensors = grouped_tensor.split_into_quantized_tensors() + + assert len(tensors) == num_tensors + + # Verify each tensor has correct shape and shares storage + for i, tensor in enumerate(tensors): + assert tensor.shape == shape[i] + assert isinstance(tensor, torch.Tensor) + assert not hasattr(tensor, "_data") # Not a quantized tensor + + # Verify data pointer is within the original grouped tensor storage + # The tensor should be a view of the original data + assert tensor.data_ptr() >= original_data_ptr + + # Calculate expected offset + expected_offset = i * (shape[i][0] * shape[i][1]) * tensor.element_size() + assert tensor.data_ptr() == original_data_ptr + expected_offset + + @pytest.mark.parametrize("quantization", _quantization_params) + def test_split_into_quantized_tensors_quantized(self, quantization: str) -> None: + """Test split_into_quantized_tensors for quantized tensors""" + num_tensors = 3 + shape = [(512, 512) for _ in range(num_tensors)] + quantizers = make_quantizer(quantization, num_tensors, shape) + + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shape=shape, + quantizer=quantizers, + device="cuda", + ) + + # Get the original data pointer + original_data_ptr = grouped_tensor.data.data_ptr() + + # Split into tensors + tensors = grouped_tensor.split_into_quantized_tensors() + + assert 
len(tensors) == num_tensors + + # Verify each tensor shares storage with the grouped tensor + for i, tensor in enumerate(tensors): + rowwise_data = _get_rowwise_data_tensor(tensor, quantization) + assert rowwise_data is not None + assert rowwise_data.data_ptr() >= original_data_ptr + numel = shape[i][0] * shape[i][1] + expected_offset = _rowwise_offset_bytes(i * numel, quantization) + assert rowwise_data.data_ptr() == original_data_ptr + expected_offset + + def test_split_varying_shapes(self) -> None: + """Test split_into_quantized_tensors with varying shapes""" + num_tensors = 3 + shape = [(128, 512), (256, 512), (384, 512)] + + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shape=shape, + quantizer=None, + device="cuda", + dtype=torch.float32, + ) + + original_data_ptr = grouped_tensor.data.data_ptr() + tensors = grouped_tensor.split_into_quantized_tensors() + + assert len(tensors) == num_tensors + + # Verify shapes and storage + cumulative_offset = 0 + for i, tensor in enumerate(tensors): + assert tensor.shape == shape[i] + expected_offset = cumulative_offset * tensor.element_size() + assert tensor.data_ptr() == original_data_ptr + expected_offset + cumulative_offset += shape[i][0] * shape[i][1] + + @pytest.mark.parametrize("quantization", _quantization_params) + def test_quantize_inplace(self, quantization: str) -> None: + """Test that quantize is done in-place for all recipes""" + num_tensors = 3 + shape = [(512, 512) for _ in range(num_tensors)] + quantizers = make_quantizer(quantization, num_tensors, shape) + + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shape=shape, + quantizer=quantizers, + device="cuda", + ) + + # Get original data pointers before quantization + original_data_ptr = grouped_tensor.data.data_ptr() + original_scale_inv_ptr = grouped_tensor.scale_inv.data_ptr() + original_scale_ptr = ( + grouped_tensor.scale.data_ptr() if grouped_tensor.scale is not None else None + ) + + # Create input tensors + input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape] + + # Quantize in place + quantized_tensors = grouped_tensor.quantize(input_tensors) + + # Verify data pointers haven't changed (in-place operation) + assert grouped_tensor.data.data_ptr() == original_data_ptr + assert grouped_tensor.scale_inv.data_ptr() == original_scale_inv_ptr + if original_scale_ptr is not None: + assert grouped_tensor.scale.data_ptr() == original_scale_ptr + + # Verify returned tensors point to the same storage + for i, qtensor in enumerate(quantized_tensors): + rowwise_data = _get_rowwise_data_tensor(qtensor, quantization) + numel = shape[i][0] * shape[i][1] + expected_offset = _rowwise_offset_bytes(i * numel, quantization) + assert rowwise_data.data_ptr() == original_data_ptr + expected_offset + + @pytest.mark.parametrize("quantization", _quantization_params) + def test_quantize_varying_shapes(self, quantization: str) -> None: + """Test quantize with varying shapes""" + num_tensors = 3 + shape = [(256, 512), (512, 512), (768, 512)] + quantizers = make_quantizer(quantization, num_tensors, shape) + + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shape=shape, + quantizer=quantizers, + device="cuda", + ) + + # Get original data pointers + original_data_ptr = grouped_tensor.data.data_ptr() + + # Create input tensors with varying shapes + input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape] + + # Quantize 
in place + quantized_tensors = grouped_tensor.quantize(input_tensors) + + # Verify data pointer hasn't changed + assert grouped_tensor.data.data_ptr() == original_data_ptr + + # Verify each tensor points to correct location + cumulative_numel = 0 + for qtensor, tensor_shape in zip(quantized_tensors, shape): + rowwise_data = _get_rowwise_data_tensor(qtensor, quantization) + expected_offset = _rowwise_offset_bytes(cumulative_numel, quantization) + assert rowwise_data.data_ptr() == original_data_ptr + expected_offset + cumulative_numel += tensor_shape[0] * tensor_shape[1] + + @pytest.mark.parametrize("quantization", _quantization_params) + def test_static_quantize_method(self, quantization: str) -> None: + """Test the static quantize method""" + num_tensors = 3 + shape = [(512, 512) for _ in range(num_tensors)] + quantizers = make_quantizer(quantization, num_tensors, shape) + + # Create input tensors + input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape] + + # Use static quantize method + grouped_tensor = GroupedTensor.create_and_quantize( + tensors=input_tensors, + quantizer=quantizers, + device="cuda", + ) + + # Verify the grouped tensor was created correctly + assert grouped_tensor.num_tensors == num_tensors + assert grouped_tensor.has_data() + + # Verify quantized_tensors were created and point to same storage + assert grouped_tensor.quantized_tensors is not None + assert len(grouped_tensor.quantized_tensors) == num_tensors + + original_data_ptr = grouped_tensor.data.data_ptr() + for i, qtensor in enumerate(grouped_tensor.quantized_tensors): + rowwise_data = _get_rowwise_data_tensor(qtensor, quantization) + numel = shape[i][0] * shape[i][1] + expected_offset = _rowwise_offset_bytes(i * numel, quantization) + assert rowwise_data.data_ptr() == original_data_ptr + expected_offset + + def test_clear(self) -> None: + """Test clear method""" + num_tensors = 3 + shape = [(256, 512) for _ in range(num_tensors)] + + grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=num_tensors, + shape=shape, + quantizer=None, + device="cuda", + dtype=torch.float32, + ) + + assert grouped_tensor.has_data() + assert grouped_tensor.num_tensors == num_tensors + + grouped_tensor.clear() + + assert not grouped_tensor.has_data() + assert grouped_tensor.num_tensors == 0 + assert grouped_tensor.data is None + assert grouped_tensor.logical_shape == (0, 0) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index e9d24c1a8e..b94cbdcd96 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -2,7 +2,7 @@ # # See LICENSE for license information. -from typing import Optional +from typing import Optional, List import torch import pytest @@ -137,6 +137,117 @@ def reset_global_fp8_state(): FP8GlobalStateManager.reset() +def check_grouped_tensor_pointers_helper(tensors, num_elems_in_byte=1, tensor_name="tensor"): + """ + Verify that tensors are stored in contiguous memory. 
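+    Each tensor's data pointer is expected to equal the previous tensor's pointer plus the previous tensor's size in bytes, i.e. (prev.numel() // num_elems_in_byte) * prev.element_size().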
+ + Args: + tensors: List or iterable of tensors to check + num_elems_in_byte: Number of elements packed per byte (1 for normal, 2 for NVFP4) + tensor_name: Name to use in error messages + """ + tensor_list = list(tensors) + if len(tensor_list) < 2: + return # Nothing to check + + for i in range(1, len(tensor_list)): + prev_tensor = tensor_list[i - 1] + curr_tensor = tensor_list[i] + + # Calculate expected offset based on previous tensor size + prev_numel = prev_tensor.numel() + expected_offset = (prev_numel // num_elems_in_byte) * prev_tensor.element_size() + + # Verify current tensor's data pointer is correctly offset + expected_ptr = prev_tensor.data_ptr() + expected_offset + actual_ptr = curr_tensor.data_ptr() + + assert ( + actual_ptr == expected_ptr + ), f"{tensor_name} {i} data pointer mismatch: expected {expected_ptr}, got {actual_ptr}" + + +def check_grouped_tensor_pointers( + weights: List[torch.Tensor], fp8_recipe: Optional[recipe.Recipe] = None +): + """ + Verify that the pointers of the weights are in contiguous memory for GroupedTensor. + TODO(ksivaman): This check can be made way more efficient but for now leaving the brute force approach. + """ + + num_elems_in_a_data_byte = 1 if fp8_recipe is None else 2 if fp8_recipe.nvfp4() else 1 + + # Check data. + if hasattr(weights[0], "_data") and weights[0]._data is not None: + data_tensors = [w._data for w in weights] + check_grouped_tensor_pointers_helper(data_tensors, num_elems_in_byte=1, tensor_name="data") + + # Check transpose. + if hasattr(weights[0], "_transpose") and weights[0]._transpose is not None: + transpose_tensors = [w._transpose for w in weights] + check_grouped_tensor_pointers_helper( + transpose_tensors, num_elems_in_byte=1, tensor_name="transpose" + ) + + # Check scale_inv. + if hasattr(weights[0], "_scale_inv") and weights[0]._scale_inv is not None: + scale_inv_tensors = [w._scale_inv for w in weights] + check_grouped_tensor_pointers_helper( + scale_inv_tensors, num_elems_in_byte=1, tensor_name="scale_inv" + ) + + # Check rowwise scale_inv. + if hasattr(weights[0], "_rowwise_scale_inv") and weights[0]._rowwise_scale_inv is not None: + scale_inv_tensors = [w._rowwise_scale_inv for w in weights] + check_grouped_tensor_pointers_helper( + scale_inv_tensors, num_elems_in_byte=1, tensor_name="rowwise_scale_inv" + ) + + # Check columnwise scale_inv. + if ( + hasattr(weights[0], "_columnwise_scale_inv") + and weights[0]._columnwise_scale_inv is not None + ): + columnwise_scale_inv_tensors = [w._columnwise_scale_inv for w in weights] + check_grouped_tensor_pointers_helper( + columnwise_scale_inv_tensors, + num_elems_in_byte=1, + tensor_name="columnwise scale_inv", + ) + + # Check rowwise amax. + if hasattr(weights[0], "_rowwise_amax") and weights[0]._rowwise_amax is not None: + rowwise_amax_tensors = [w._rowwise_amax for w in weights] + check_grouped_tensor_pointers_helper( + rowwise_amax_tensors, num_elems_in_byte=1, tensor_name="rowwise amax" + ) + + # Check columnwise amax. + if hasattr(weights[0], "_columnwise_amax") and weights[0]._columnwise_amax is not None: + columnwise_amax_tensors = [w._columnwise_amax for w in weights] + check_grouped_tensor_pointers_helper( + columnwise_amax_tensors, num_elems_in_byte=1, tensor_name="columnwise amax" + ) + + # Check rowwise data. 
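+    # NVFP4 packs two 4-bit values per byte; num_elems_in_a_data_byte accounts for this when computing the expected rowwise/columnwise data offsets below.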
+ if hasattr(weights[0], "_rowwise_data") and weights[0]._rowwise_data is not None: + rowwise_data_tensors = [w._rowwise_data for w in weights] + check_grouped_tensor_pointers_helper( + rowwise_data_tensors, + num_elems_in_byte=num_elems_in_a_data_byte, + tensor_name="rowwise data", + ) + + # Check columnwise data. + if hasattr(weights[0], "_columnwise_data") and weights[0]._columnwise_data is not None: + columnwise_data_tensors = [w._columnwise_data for w in weights] + check_grouped_tensor_pointers_helper( + columnwise_data_tensors, + num_elems_in_byte=num_elems_in_a_data_byte, + tensor_name="columnwise data", + ) + + def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad): te_inp_hidden_states = torch.randn( (config.max_seqlen_q, config.batch_size, config.hidden_size), @@ -495,9 +606,18 @@ def test_sanity_grouped_linear( use_fp8 = fp8_recipe is not None with quantized_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe): te_grouped_linear = GroupedLinear( - num_gemms, config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype + num_gemms, + config.hidden_size, + ffn_hidden_size, + bias=use_bias, + params_dtype=dtype, ).cuda() + # Verify that weights are stored in contiguous GroupedTensor storage. + weights = [getattr(te_grouped_linear, f"weight{i}") for i in range(num_gemms)] + if fp8_recipe is None or not (fp8_recipe.delayed() or fp8_recipe.float8_current_scaling()): + check_grouped_tensor_pointers(weights, fp8_recipe) + inp_hidden_states = torch.randn( num_tokens, config.hidden_size, dtype=dtype, requires_grad=True ).cuda() @@ -956,7 +1076,13 @@ def test_replace_raw_data_for_float8tensor(): random_bf16_data = torch.randn(fp8_tensor.shape, dtype=torch.bfloat16, device="cuda") fp8_quantizer.update_quantized(random_bf16_data, fp8_tensor) - attrs_to_check = ["_quantizer", "_fp8_dtype", "_scale_inv", "_transpose", "_transpose_invalid"] + attrs_to_check = [ + "_quantizer", + "_fp8_dtype", + "_scale_inv", + "_transpose", + "_transpose_invalid", + ] attrs = {} for attr in attrs_to_check: attrs[attr] = getattr(fp8_tensor, attr) diff --git a/transformer_engine/common/recipe/__init__.py b/transformer_engine/common/recipe/__init__.py index 64ee2a5a16..18577b0eb4 100644 --- a/transformer_engine/common/recipe/__init__.py +++ b/transformer_engine/common/recipe/__init__.py @@ -88,33 +88,40 @@ class Recipe: Base recipe class. 
""" - def nvfp4(self): + @classmethod + def nvfp4(cls): """Whether the given recipe is NVFP4 1D block scaling.""" - return isinstance(self, NVFP4BlockScaling) + return issubclass(cls, NVFP4BlockScaling) - def mxfp8(self): + @classmethod + def mxfp8(cls): """Whether the given recipe is MXFP8 block scaling.""" - return isinstance(self, MXFP8BlockScaling) + return issubclass(cls, MXFP8BlockScaling) - def delayed(self): + @classmethod + def delayed(cls): """Whether the given recipe is delayed scaling.""" - return isinstance(self, DelayedScaling) + return issubclass(cls, DelayedScaling) - def float8_current_scaling(self): + @classmethod + def float8_current_scaling(cls): """Whether the given recipe is (per-tensor) current scaling.""" - return isinstance(self, Float8CurrentScaling) + return issubclass(cls, Float8CurrentScaling) - def float8_per_tensor_scaling(self): + @classmethod + def float8_per_tensor_scaling(cls): """Whether the given recipe is per-tensor scaling.""" - return isinstance(self, (DelayedScaling, Float8CurrentScaling)) + return issubclass(cls, (DelayedScaling, Float8CurrentScaling)) - def float8_block_scaling(self): + @classmethod + def float8_block_scaling(cls): """Whether the given recipe is float8 blockwise scaling.""" - return isinstance(self, Float8BlockScaling) + return issubclass(cls, Float8BlockScaling) - def custom(self): + @classmethod + def custom(cls): """Whether the given recipe is custom.""" - return isinstance(self, CustomRecipe) + return issubclass(cls, CustomRecipe) @dataclass() diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 6cb685a3f6..b6596bc2e9 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -13,6 +13,7 @@ import transformer_engine_torch as tex from transformer_engine.common.recipe import Recipe +from transformer_engine.pytorch.tensor.storage.grouped_tensor import GroupedTensor from .base import ( get_dummy_wgrad, TransformerEngineBaseModule, @@ -147,7 +148,10 @@ def forward( # tensors (like scales), but bulk allocation shares storage across all tensors, # so if scales can't be offloaded, nothing in the group can be offloaded. inputmats = tex.split_quantize( - inp_view, m_splits, input_quantizers, disable_bulk_allocation=cpu_offloading + inp_view, + m_splits, + input_quantizers, + disable_bulk_allocation=cpu_offloading, ) elif debug: inputmats = DebugQuantizer.multi_tensor_quantize( @@ -365,7 +369,10 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], for i in range(ctx.num_gemms): grad_biases[i] = grad_output_mats[i].sum(dim=0) grad_output = DebugQuantizer.multi_tensor_quantize( - grad_output_view, ctx.grad_output_quantizers, ctx.m_splits, ctx.activation_dtype + grad_output_view, + ctx.grad_output_quantizers, + ctx.m_splits, + ctx.activation_dtype, ) else: # Only split grad output. 
Grad bias is fused with @@ -436,7 +443,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if ctx.input_quantizers[0] is not None: for input_quantizer in ctx.input_quantizers: if isinstance( - input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) + input_quantizer, + (Float8Quantizer, Float8CurrentScalingQuantizer), ): input_quantizer.set_usage(rowwise=True, columnwise=True) else: @@ -446,7 +454,10 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmats = tex.split_quantize(inp_view, ctx.m_splits, ctx.input_quantizers) elif ctx.debug: inputmats = DebugQuantizer.multi_tensor_quantize( - inp_view, ctx.input_quantizers, ctx.m_splits, ctx.activation_dtype + inp_view, + ctx.input_quantizers, + ctx.m_splits, + ctx.activation_dtype, ) else: inputmats = torch.split( @@ -616,7 +627,7 @@ def __init__( ) -> None: super().__init__(name) - params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype + self.params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype self.num_gemms = num_gemms self.in_features = in_features self.out_features = out_features @@ -631,12 +642,19 @@ def __init__( assert ( not ub_overlap_rs and not ub_overlap_ag ), "GroupedLinear doesn't support Userbuffer overlap." + self.init_method = init_method self.get_rng_state_tracker = get_rng_state_tracker self.rng_tracker_name = rng_tracker_name self.wgrad_store = WeightGradStore(delay_wgrad_compute) - self._offsets = {"input": 0, "weight": 1, "output": 2, "grad_output": 0, "grad_input": 1} + self._offsets = { + "input": 0, + "weight": 1, + "output": 2, + "grad_output": 0, + "grad_input": 1, + } self._num_fp8_tensors_per_gemm = { "fwd": 3, "bwd": 2, @@ -678,7 +696,7 @@ def __init__( self.out_features, self.in_features, device=device, - dtype=params_dtype, + dtype=self.params_dtype, ), ), init_fn=init_method, @@ -694,13 +712,13 @@ def __init__( torch.empty( self.out_features, device=device, - dtype=params_dtype, + dtype=self.params_dtype, ), ), init_fn=init_method_constant(0.0), ) else: - bias = torch.Tensor().to(dtype=params_dtype, device=device) + bias = torch.Tensor().to(dtype=self.params_dtype, device=device) setattr(self, f"bias{i}", bias) if self.primary_weights_in_fp8: @@ -724,8 +742,61 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: if recipe.float8_current_scaling(): self._customize_quantizers_float8_current_scaling(fwd, recipe) + def make_grouped_weights(self, defer_init=False) -> None: + """ + Convert parameters into a GroupedTensor and re-register them as parameters. + """ + + if defer_init: + return + + weight_quantizers = self._get_weight_quantizers() + recipe = ( + weight_quantizers[0]._get_compatible_recipe() + if weight_quantizers and weight_quantizers[0] is not None + else None + ) + if recipe is not None and (recipe.delayed() or recipe.float8_current_scaling()): + self.set_tensor_parallel_attributes(defer_init=defer_init) + return + + weights = [getattr(self, f"weight{i}") for i in range(self.num_gemms)] + + # Create the weight storage. + grouped_weights = GroupedTensor.make_grouped_tensor_with_shapes( + num_tensors=self.num_gemms, + shape=[(self.out_features, self.in_features)] * self.num_gemms, + quantizer=weight_quantizers[0], + dtype=self.params_dtype, + device=weights[0].device, + ) + + # Copy existing params into storage. 
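+        # If params are already FP8 (primary_weights_in_fp8), copy the underlying quantized buffers with copy_from_storage; otherwise a plain high-precision copy_ is enough.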
+ with torch.no_grad(): + for i in range(self.num_gemms): + if self.primary_weights_in_fp8: + grouped_weights.quantized_tensors[i].copy_from_storage(weights[i]) + else: + grouped_weights.quantized_tensors[i].copy_(weights[i]) + + # Re-register the grouped weights as parameters. + for i in range(self.num_gemms): + self.register_parameter( + f"weight{i}", + torch.nn.Parameter(grouped_weights.quantized_tensors[i]), + init_fn=self.init_method, + get_rng_state_tracker=self.get_rng_state_tracker, + fp8_meta_index=self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"], + ) + + self.set_tensor_parallel_attributes(defer_init=defer_init) + def reset_parameters(self, defer_init=False): super().reset_parameters(defer_init=defer_init) + self.make_grouped_weights(defer_init=defer_init) + + def set_tensor_parallel_attributes(self, defer_init=False) -> None: + """Set attributes needed for TP""" if not defer_init: # Set parallelism attributes for linear weights @@ -925,7 +996,7 @@ def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorStorage def _get_weight_quantizers(self) -> List[Quantizer]: """Get the weight quantizers of the module.""" - if not self.fp8 and not self.fp8_calibration: + if not self.fp8 and not self.fp8_calibration and not self.primary_weights_in_fp8: return [None] * self.num_gemms weight_quantizers = [ self.quantizers["scaling_fwd"][ @@ -934,7 +1005,7 @@ def _get_weight_quantizers(self) -> List[Quantizer]: for i in range(self.num_gemms) ] for i in range(self.num_gemms): - weight_quantizers[i].internal = True + weight_quantizers[i].internal = not self.primary_weights_in_fp8 return weight_quantizers def _get_quantizers(self): diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index 0a6ad61ff0..d78677bc83 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -69,7 +69,9 @@ def get_usages(self) -> Dict[str, bool]: f"{self.__class__.__name__} class does not implement get_usages function" ) - def prepare_for_saving(self) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorStorage]: + def prepare_for_saving( + self, + ) -> Tuple[list[Optional[torch.Tensor]], QuantizedTensorStorage]: """Prepare the tensor base for saving for backward""" raise NotImplementedError( f"{self.__class__.__name__} class does not implement prepare_for_saving function" @@ -115,11 +117,18 @@ def update_quantizer(self, quantizer: Quantizer): warnings.warn("Quantizer is being updated, this may affect model behavior") self._quantizer = quantizer + def copy_from_storage(self, src: QuantizedTensorStorage) -> None: + """Copy data from another QuantizedTensorStorage.""" + raise NotImplementedError( + f"{self.__class__.__name__} class does not implement copy_from_storage function" + ) + def prepare_for_saving( *tensors: Union[torch.Tensor, QuantizedTensorStorage], ) -> Tuple[ - list[Optional[Union[torch.Tensor, torch.nn.Parameter]]], list[Optional[QuantizedTensorStorage]] + list[Optional[Union[torch.Tensor, torch.nn.Parameter]]], + list[Optional[QuantizedTensorStorage]], ]: """Prepare tensors for saving. 
Needed because save_for_backward accepts only torch.Tensor/torch.nn.Parameter types, while we want to be able to save @@ -144,7 +153,10 @@ def restore_from_saved( return_saved_tensors: bool = False, ) -> ( list[Optional[torch.Tensor | QuantizedTensorStorage]] - | tuple[list[Optional[torch.Tensor | QuantizedTensorStorage]], list[Optional[torch.Tensor]]] + | tuple[ + list[Optional[torch.Tensor | QuantizedTensorStorage]], + list[Optional[torch.Tensor]], + ] ): """Recombine the tensor data and metadata during backward pass.""" tensor_objects = [] diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 3aeace0a77..55bca49af3 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -11,7 +11,11 @@ import transformer_engine_torch as tex from transformer_engine_torch import DType as TE_DType -from transformer_engine.common.recipe import DelayedScaling, Float8CurrentScaling, Recipe +from transformer_engine.common.recipe import ( + DelayedScaling, + Float8CurrentScaling, + Recipe, +) from ..utils import canonicalize_process_group, devices_match from .storage.float8_tensor_storage import Float8TensorStorage, _FromFloat8Func from ..quantized_tensor import QuantizedTensor, Quantizer @@ -154,6 +158,10 @@ def calibrate(self, tensor: torch.Tensor) -> None: amin, amax = tensor.aminmax() self.amax.copy_(torch.max(-amin, amax)) + def get_columnwise_shape(self, rowwise_data_shape: Iterable[int]) -> Tuple[int, ...]: + """Calculate the shape of the columnwise (transposed) data for per-tensor Float8 quantization.""" + return [rowwise_data_shape[-1]] + list(rowwise_data_shape[:-1]) + def create_tensor_from_data( self, data: torch.Tensor, @@ -408,6 +416,10 @@ def create_tensor_from_data( quantizer=self, ) + def get_columnwise_shape(self, rowwise_data_shape: Iterable[int]) -> Tuple[int, ...]: + """Calculate the shape of the columnwise (transposed) data for per-tensor Float8 quantization.""" + return [rowwise_data_shape[-1]] + list(rowwise_data_shape[:-1]) + def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor: """Function using primitives with ONNX defined translations.""" if tensor.dtype != torch.float32: @@ -769,7 +781,10 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): kwargs, ) return Float8Tensor.make_like( - tensor, data=func_out, data_transpose=func_transposed_out, shape=func_out.shape + tensor, + data=func_out, + data_transpose=func_transposed_out, + shape=func_out.shape, ) if func == torch.ops.aten.detach.default: diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 8dd2255d89..41d6c87f2b 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -164,6 +164,49 @@ def calibrate(self, tensor: torch.Tensor) -> None: # TODO(ksivamani): No calibration needed for mxfp8? pass + def get_scale_shape( + self, + shape: Iterable[int], + columnwise: bool, + ) -> Tuple[int, int]: + """Calculate the shape of the scaling tensor for MXFP8 1D blockwise quantization. + + This method determines the shape of the scaling tensor needed for blockwise quantization, + taking into account the input tensor shape and whether columnwise scaling is used. 
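+        Each scale factor covers a block of MXFP8_BLOCK_SCALING_SIZE (32) consecutive values along the dimension being scaled.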
+ + Parameters + ---------- + shape : Iterable[int] + Shape of the input tensor to be quantized + columnwise : bool + Whether to use columnwise scaling (True) or rowwise scaling (False) + + Returns + ------- + Tuple[int, int] + Shape of the scaling tensor as (outer_dim, inner_dim) + For MXFP8 1D blockwise quantization, blocksize is 32 + Swizzle kernel will be performed before GEMM to suit the need of CuBLAS. + CuBLAS doc: https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout + """ + if columnwise: + # Columnwise: scale_inv shape is [prod(shape[:-1]) // BLOCK_SIZE, shape[-1]] + # with padding to multiples of [4, 128] + return ( + round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4), + round_up_to_nearest_multiple(shape[-1], 128), + ) + # Rowwise: scale_inv shape is [prod(shape[:-1]), shape[-1] // BLOCK_SIZE] + # with padding to multiples of [128, 4] + return ( + round_up_to_nearest_multiple(math.prod(shape[:-1]), 128), + round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4), + ) + + def get_columnwise_shape(self, rowwise_data_shape: Tuple[int, ...]) -> Tuple[int, ...]: + """Calculate the shape of the columnwise data for MXFP8 1D blockwise quantization.""" + return rowwise_data_shape + def create_tensor_from_data( self, data: torch.Tensor, @@ -704,7 +747,7 @@ def fsdp_post_all_gather( columnwise_scale_inv=columnwise_scale_inv, fp8_dtype=fp8_dtype, dtype=param_dtype, - shape=rowwise_data.shape if rowwise_data is not None else columnwise_data.shape, + shape=(rowwise_data.shape if rowwise_data is not None else columnwise_data.shape), quantizer=self._quantizer, with_gemm_swizzled_scales=False, ) diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py index 101cf78a8f..66f986a900 100644 --- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py +++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py @@ -341,7 +341,10 @@ def make_empty( ) columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True) columnwise_scale_inv = torch.empty( - columnwise_scale_shape, dtype=torch.uint8, device=device, pin_memory=pin_memory + columnwise_scale_shape, + dtype=torch.uint8, + device=device, + pin_memory=pin_memory, ) amax_columnwise = torch.zeros( 1, dtype=torch.float32, device=device, pin_memory=pin_memory diff --git a/transformer_engine/pytorch/tensor/storage/__init__.py b/transformer_engine/pytorch/tensor/storage/__init__.py index d7a2719200..7c8a014c1d 100644 --- a/transformer_engine/pytorch/tensor/storage/__init__.py +++ b/transformer_engine/pytorch/tensor/storage/__init__.py @@ -7,3 +7,4 @@ from .mxfp8_tensor_storage import MXFP8TensorStorage # noqa: F401 from .float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage # noqa: F401 from .nvfp4_tensor_storage import NVFP4TensorStorage # noqa: F401 +from .grouped_tensor import GroupedTensor # noqa: F401 diff --git a/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py index 278d7dc039..4cd6d19cd8 100644 --- a/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/float8_blockwise_tensor_storage.py @@ -73,6 +73,24 @@ def clear(self): if t is not None: t.data = _empty_tensor() + def copy_from_storage(self, src: QuantizedTensorStorage) -> None: + """Copy data buffers from another Float8BlockwiseQTensorStorage.""" + if not isinstance(src, 
Float8BlockwiseQTensorStorage): + raise TypeError("copy_from_storage expects Float8BlockwiseQTensorStorage") + if self._fp8_dtype != src._fp8_dtype: + raise RuntimeError("FP8 dtype mismatch in copy_from_storage") + if self._is_2D_scaled != src._is_2D_scaled: + raise RuntimeError("Scale layout mismatch in copy_from_storage") + + def _copy_optional(dst: Optional[torch.Tensor], src_tensor: Optional[torch.Tensor]): + if dst is not None and src_tensor is not None: + dst.copy_(src_tensor) + + _copy_optional(self._rowwise_data, src._rowwise_data) + _copy_optional(self._columnwise_data, src._columnwise_data) + _copy_optional(self._rowwise_scale_inv, src._rowwise_scale_inv) + _copy_optional(self._columnwise_scale_inv, src._columnwise_scale_inv) + def get_metadata(self) -> Dict[str, Any]: """Get this tensor's metadata.""" return { diff --git a/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py index adf3ce8aea..9adb86c453 100644 --- a/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/float8_tensor_storage.py @@ -104,6 +104,24 @@ def clear(self): t.data = _empty_tensor() self._transpose_invalid = True + def copy_from_storage(self, src: QuantizedTensorStorage) -> None: + """Copy data buffers from another Float8TensorStorage.""" + if not isinstance(src, Float8TensorStorage): + raise TypeError("copy_from_storage expects Float8TensorStorage") + if self._fp8_dtype != src._fp8_dtype: + raise RuntimeError("FP8 dtype mismatch in copy_from_storage") + + def _copy_optional( + dst: Optional[torch.Tensor], + src_tensor: Optional[torch.Tensor], + ): + if dst is not None and src_tensor is not None: + dst.copy_(src_tensor) + + _copy_optional(self._data, src._data) + _copy_optional(self._transpose, src._transpose) + _copy_optional(self._scale_inv, src._scale_inv) + def get_metadata(self) -> Dict[str, Any]: """Get this tensor's metadata.""" return { diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor.py new file mode 100644 index 0000000000..dad4d1d0ea --- /dev/null +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor.py @@ -0,0 +1,942 @@ +# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Grouped tensor class for handling collections of tensors with different shapes""" +from __future__ import annotations +from typing import Optional, Tuple, List, Union +import math + +import torch + +from ...quantized_tensor import QuantizedTensorStorage, Quantizer + +from ..mxfp8_tensor import MXFP8Tensor +from ..nvfp4_tensor import NVFP4Tensor +from ..float8_tensor import Float8Tensor +from ..float8_blockwise_tensor import Float8BlockwiseQTensor +from .float8_tensor_storage import Float8TensorStorage +from .mxfp8_tensor_storage import MXFP8TensorStorage +from .float8_blockwise_tensor_storage import Float8BlockwiseQTensorStorage +from .nvfp4_tensor_storage import NVFP4TensorStorage + + +class GroupedTensor: + """ + EXPERIMENTAL FEATURE AND SUBJECT TO CHANGE. + + Grouped tensor is a collection of tensors with different shapes but the same dtype and scaling mode. + + Shape Representation: + - logical_shape: 2D shape representing the conceptual layout, i.e. 
the shape when member tensors + are flattened to 2D and stacked together (REQUIRED) + + When all_same_shape(): [num_tensors * M, N] where each tensor is (M, N) + + When varying_first_dim(): [~sum_of_first_dims, N] where N is common + + When varying_last_dim(): [M, ~sum_of_last_dims] where M is common + + When varying_both_dims(): [1, total_elements] (fully flattened) + + - first_dims and last_dims are OPTIONAL (None if dimension is uniform) + + None first_dims: all tensors have the same first dimension + + None last_dims: all tensors have the same last dimension + + Both None: all tensors have identical shapes + + Both set: each tensor has unique shape (first_dims[i], last_dims[i]) + + Data Layout: + - ALL data fields are stored as 1D flattened arrays (data, columnwise_data, scale_inv, etc.) + - logical_shape provides the conceptual 2D interpretation + - All data is stored on device in contiguous layout + + Note: This structure is used only for combined storage of multiple tensors with the same dtype and scaling mode. + """ + + def __init__( + self, + num_tensors: int, + shape: List[Tuple[int, int]], + quantizer: Optional[Quantizer] = None, + dtype: Optional[torch.dtype] = None, + data: Optional[torch.Tensor] = None, + columnwise_data: Optional[torch.Tensor] = None, + scale_inv: Optional[torch.Tensor] = None, + columnwise_scale_inv: Optional[torch.Tensor] = None, + amax: Optional[torch.Tensor] = None, + columnwise_amax: Optional[torch.Tensor] = None, + scale: Optional[torch.Tensor] = None, + first_dims: Optional[torch.Tensor] = None, + last_dims: Optional[torch.Tensor] = None, + tensor_offsets: Optional[torch.Tensor] = None, + offsets: Optional[List[int]] = None, + scale_inv_offsets: Optional[List[int]] = None, + columnwise_scale_inv_offsets: Optional[List[int]] = None, + logical_shape: Optional[Tuple[int, int]] = None, + ) -> None: + """ + Initialize a GroupedTensor. + + Args: + num_tensors: Number of tensors in the group + shape: 2D shape of each tensor (len num_tensors) + quantizer: Quantizer for the grouped tensor + data: Row-wise data buffer (1D flattened) + columnwise_data: Column-wise data buffer (1D flattened) + scale_inv: Row-wise scale inverse buffer + columnwise_scale_inv: Column-wise scale inverse buffer + amax: Row-wise amax buffer + columnwise_amax: Column-wise amax buffer + scale: Scale buffer (for FP8-DS only) + first_dims: Device tensor of int64 array of length num_tensors (or None if uniform) + last_dims: Device tensor of int64 array of length num_tensors (or None if uniform) + tensor_offsets: Device tensor of int64 array of length num_tensors (or None if uniform) + offsets: Vector of integer offsets for each tensor. + logical_shape: 2D tuple representing conceptual shape + """ + self.num_tensors = num_tensors + self.quantizer = quantizer + self.shape = shape + self.dtype = ( + dtype if dtype is not None else torch.float32 + ) # Default to float32 if not provided + + # Data buffers + self.data = data + self.columnwise_data = columnwise_data + self.scale_inv = scale_inv + self.columnwise_scale_inv = columnwise_scale_inv + self.amax = amax + self.columnwise_amax = columnwise_amax + self.scale = scale + + # For convenient indexing for python GroupedTensor API. 
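+        # scale_inv_offsets[i] is the element offset of tensor i's scale-inverse block within the flattened scale_inv buffer (one entry per tensor for per-tensor recipes, cumulative block-scale counts otherwise).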
+ self.scale_inv_offsets = scale_inv_offsets + self.columnwise_scale_inv_offsets = columnwise_scale_inv_offsets + + # Shape information (OPTIONAL - None if dimension is uniform across all tensors) + # first_dims[i] = first dimension of tensor i (None if all tensors have same first dim) + # last_dims[i] = last dimension of tensor i (None if all tensors have same last dim) + self.first_dims = ( + first_dims # Device pointer to int64_t array of length num_tensors (or None) + ) + self.last_dims = ( + last_dims # Device pointer to int64_t array of length num_tensors (or None) + ) + + # Offsets for indexing into contiguous 1D layout (OPTIONAL - not needed if all_same_shape()) + # tensor_offsets[i] = element offset to start of tensor i (cumulative sum of numel for tensors 0..i-1) + # Usage: tensor_i_ptr = data.data_ptr() + tensor_offsets[i] * element_size + # If None and all_same_shape(): offset[i] = i * M * N (where M, N are common dimensions) + self.tensor_offsets = ( + tensor_offsets # Device pointer to int64_t array of length num_tensors (or None) + ) + self.offsets = offsets # Vector of integer offsets for each tensor. + + # Logical shape: conceptual 2D shape of the grouped data (REQUIRED) + # Represents how the 1D flattened data should be interpreted as 2D + # Always 2D with positive dimensions + self.logical_shape = logical_shape if logical_shape is not None else (0, 0) + + # Hold a reference to the quantized tensors that occupy same storage as the GroupedTensor. + # Used as a convenience. + self.quantized_tensors = None + + def has_data(self) -> bool: + """ + Check if the tensor has row-wise data. + + Returns: + True if data buffer is initialized, False otherwise + """ + return self.data is not None + + def has_columnwise_data(self) -> bool: + """ + Check if the tensor has column-wise data. + + Returns: + True if columnwise_data buffer is initialized, False otherwise + """ + return self.columnwise_data is not None + + def all_same_first_dim(self) -> bool: + """ + Check if all tensors in the group have the same first dimension. + + Returns: + True if first dimension is uniform across all tensors + """ + return self.first_dims is None + + def all_same_last_dim(self) -> bool: + """ + Check if all tensors in the group have the same last dimension. + + Returns: + True if last dimension is uniform across all tensors + """ + return self.last_dims is None + + def all_same_shape(self) -> bool: + """ + Check if all tensors in the group have identical shapes. + + Returns: + True if all tensors have the same shape + """ + return self.first_dims is None and self.last_dims is None + + def varying_both_dims(self) -> bool: + """ + Check if both dimensions vary across tensors. + + Returns: + True if both first and last dimensions vary + """ + return self.first_dims is not None and self.last_dims is not None + + def get_common_first_dim(self) -> int: + """ + Get the common first dimension when all tensors share it. 
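+        When all shapes are uniform this is logical_shape[0] // num_tensors; when only the last dim varies it is logical_shape[0] itself.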
+ + Returns: + The common first dimension + + Raises: + RuntimeError: If first dimension varies across tensors or logical_shape is not 2D + """ + if not self.all_same_first_dim(): + raise RuntimeError("First dim varies across tensors") + if len(self.logical_shape) != 2: + raise RuntimeError("Logical shape must be 2D") + + if self.all_same_shape(): + # When both dims are uniform: logical_shape = [num_tensors * M, N] + return self.logical_shape[0] // self.num_tensors + # When varying last dims but not first dim: logical_shape = [M, sum_of_last_dims] + return self.logical_shape[0] + + def get_common_last_dim(self) -> int: + """ + Get the common last dimension when all tensors share it. + + Returns: + The common last dimension + + Raises: + RuntimeError: If last dimension varies across tensors or logical_shape is not 2D + """ + if not self.all_same_last_dim(): + raise RuntimeError("Last dim varies across tensors") + if len(self.logical_shape) != 2: + raise RuntimeError("Logical shape must be 2D") + + # For both uniform and varying first dim cases: logical_shape[1] is the common last dim + return self.logical_shape[1] + + def get_dtype(self) -> torch.dtype: + """ + Get the high precision data type of the tensor. + + Returns: + The high precision dtype of the data buffer + """ + + return self.dtype + + def clear(self) -> None: + """ + Reset tensor data and clear all buffers. + """ + self.data = None + self.columnwise_data = None + self.scale_inv = None + self.columnwise_scale_inv = None + self.amax = None + self.columnwise_amax = None + self.scale = None + self.first_dims = None + self.last_dims = None + self.tensor_offsets = None + self.logical_shape = (0, 0) + self.num_tensors = 0 + self.quantizer = None + self.quantized_tensors = None + self.offsets = None + self.scale_inv_offsets = None + self.columnwise_scale_inv_offsets = None + + def __repr__(self) -> str: + """String representation of the GroupedTensor.""" + return ( + f"GroupedTensor(num_tensors={self.num_tensors}, " + f"shape={self.shape}, " + f"logical_shape={self.logical_shape}, " + f"dtype={self.get_dtype()})" + ) + + def __str__(self) -> str: + """User-friendly string representation.""" + shape_info = [] + if self.all_same_shape(): + shape_info.append("uniform shape") + else: + if not self.all_same_first_dim(): + shape_info.append("varying first dim") + if not self.all_same_last_dim(): + shape_info.append("varying last dim") + + return ( + f"GroupedTensor with {self.num_tensors} tensors " + f"({', '.join(shape_info) if shape_info else 'uniform'}), " + f"logical_shape={self.logical_shape}, " + f"dtype={self.get_dtype()}" + ) + + @staticmethod + def make_grouped_tensor_with_shapes( + num_tensors: int, + shape: List[Tuple[int, int]], + quantizer: Optional[Quantizer] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> GroupedTensor: + """ + Create a GroupedTensor for storing multiple weight tensors of the same shape. + + Args: + num_tensors: Number of tensors + shape: 2D shape of each tensor (len num_tensors) + quantizer: Quantizer for each tensor + device: Device to allocate tensors on, defaults to current cuda device + dtype: Data type of the tensor (for high precision case) + + Returns: + A GroupedTensor. 
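+        Note: per-tensor shapes may differ in their first dimension; the last dimension must be uniform (asserted below).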
+ """ + + # First dim + first_dim_list = [s[0] for s in shape] + uniform_first_dim = all(first_dim_list[0] == x for x in first_dim_list) + logical_first_dim = sum(first_dim_list) + if uniform_first_dim: + first_dims = None + else: + first_dims = torch.tensor([s[0] for s in shape], dtype=torch.int64, device=device) + + # Last dim + last_dim_list = [s[1] for s in shape] + logical_last_dim = last_dim_list[0] + assert all(logical_last_dim == x for x in last_dim_list), "Last dims should be uniform" + + return GroupedTensor.make_grouped_tensor( + num_tensors=num_tensors, + first_dims=first_dims, + last_dims=None, + logical_first_dim=logical_first_dim, + logical_last_dim=logical_last_dim, + quantizer=quantizer, + device=device, + dtype=dtype, + ) + + @staticmethod + def make_grouped_tensor( + num_tensors: int, + first_dims: Optional[torch.Tensor], + last_dims: Optional[torch.Tensor], + logical_first_dim: int, + logical_last_dim: int, + quantizer: Optional[Quantizer] = None, + device: Optional[torch.device] = None, + dtype: Optional[torch.dtype] = None, + ) -> GroupedTensor: + """ + Create a GroupedTensor for storing multiple weight tensors of the same shape. + + Args: + num_tensors: Number of tensors + first_dims: Device tensor of int64 array of length num_tensors (or None if uniform) + last_dims: Device tensor of int64 array of length num_tensors (or None if uniform) + logical_first_dim: Logical first dimension + logical_last_dim: Logical last dimension + quantizer: Quantizer for each tensor + Used to figure out the recipe and what to allocate. + device: Device to allocate tensors on, defaults to current cuda device + dtype: Data type of the tensor (for high precision case) + + Returns: + A GroupedTensor. + """ + + # Set device + if device is None: + device = torch.cuda.current_device() + + # Shape patterns and validation. + all_same_first = first_dims is None + all_same_last = last_dims is None + + assert all_same_last, "Last dim must be uniform for GroupedTensor" + assert logical_first_dim > 0, "Logical first dim must be positive for GroupedTensor" + assert logical_last_dim > 0, "Logical last dim must be positive for GroupedTensor" + + # assert ( + # logical_first_dim % 128 == 0 + # ), "Logical first dim must be divisible by 128" + # assert logical_last_dim % 128 == 0, "Logical last dim must be divisible by 128" + + # Calculate tensor offsets (cumulative element offsets) + tensor_offsets = None + offsets = None + shape = [] + if not all_same_first: + # Need explicit offsets for non-uniform shapes + # Offsets are based on number of elements and not pointers. + # Kernels need to calculate precise pointers based on size of elements. + + # TODO(ksivaman): Single kernel + remove the host offset calculation. 
+ tensor_offsets = torch.cat( + [ + torch.zeros(1, device=first_dims.device, dtype=first_dims.dtype), + torch.cumsum(first_dims * logical_last_dim, dim=0), + ] + ) + offsets = tensor_offsets.tolist() + first_dims_list = first_dims.tolist() + for i in range(num_tensors): + shape.append((first_dims_list[i], logical_last_dim)) + else: + offsets = [ + i * logical_first_dim * logical_last_dim // num_tensors + for i in range(num_tensors + 1) + ] + for i in range(num_tensors): + shape.append((logical_first_dim // num_tensors, logical_last_dim)) + + # Calculate logical shape based + logical_shape = (logical_first_dim, logical_last_dim) + + no_quantization = quantizer is None + + rowwise_usage = quantizer.rowwise_usage if not no_quantization else True + columnwise_usage = quantizer.columnwise_usage if not no_quantization else False + + # Calculate total elements across all tensors + total_elements = logical_first_dim * logical_last_dim + + data = None + columnwise_data = None + scale_inv = None + columnwise_scale_inv = None + amax = None + columnwise_amax = None + scale = None + scale_inv_offsets = None + columnwise_scale_inv_offsets = None + if no_quantization: + assert dtype is not None, "dtype must be provided for unquantized GroupedTensor" + if rowwise_usage: + # Allocate rowwise data buffer (1D flattened, uint8) + data = torch.empty(total_elements, dtype=dtype, device=device) + + if columnwise_usage: + # Allocate columnwise data buffer (1D flattened, uint8) + columnwise_data = torch.empty(total_elements, dtype=dtype, device=device) + elif quantizer._get_compatible_recipe().mxfp8(): + if rowwise_usage: + # Allocate rowwise data buffer (1D flattened, uint8) + data = torch.empty(total_elements, dtype=torch.uint8, device=device) + # Scale inverse buffer for MXFP8 - complex shape based on block scaling + # For grouped tensors, we need to calculate scale_inv size for all tensors + total_scale_elements = 0 + scale_inv_offsets = [0] + for i, s in enumerate(shape): + scale_inv_shape = quantizer.get_scale_shape(s, False) + scale_elements = math.prod(scale_inv_shape) + total_scale_elements += scale_elements + if i < num_tensors - 1: + scale_inv_offsets.append(total_scale_elements) + scale_inv = torch.empty(total_scale_elements, dtype=torch.uint8, device=device) + + if columnwise_usage: + # Allocate columnwise data buffer (1D flattened, uint8) + columnwise_data = torch.empty(total_elements, dtype=torch.uint8, device=device) + # Columnwise scale inverse buffer + total_columnwise_scale_elements = 0 + columnwise_scale_inv_offsets = [0] + for i, s in enumerate(shape): + scale_inv_shape = quantizer.get_scale_shape(s, False) + columnwise_scale_elements = math.prod(scale_inv_shape) + total_columnwise_scale_elements += columnwise_scale_elements + if i < num_tensors - 1: + columnwise_scale_inv_offsets.append(total_columnwise_scale_elements) + columnwise_scale_inv = torch.empty( + total_columnwise_scale_elements, dtype=torch.uint8, device=device + ) + elif quantizer._get_compatible_recipe().delayed(): + if rowwise_usage: + # Allocate rowwise data buffer (1D flattened, uint8) + data = torch.empty(total_elements, dtype=torch.uint8, device=device) + # Scale inverse - one per tensor + scale_inv = torch.empty(num_tensors, dtype=torch.float32, device=device) + # One scale per tensor, so offsets are simply 0, 1, 2, ..., num_tensors-1 + scale_inv_offsets = list(range(num_tensors)) + + if columnwise_usage: + # Allocate columnwise data buffer (1D flattened, uint8) + columnwise_data = torch.empty(total_elements, 
dtype=torch.uint8, device=device) + # Columnwise scale inverse - one per tensor + columnwise_scale_inv = torch.empty(num_tensors, dtype=torch.float32, device=device) + # One scale per tensor, so offsets are simply 0, 1, 2, ..., num_tensors-1 + columnwise_scale_inv_offsets = list(range(num_tensors)) + + # Amax buffer for delayed scaling - one per tensor + amax = torch.empty(num_tensors, dtype=torch.float32, device=device) + elif quantizer._get_compatible_recipe().nvfp4(): + + if rowwise_usage: + # Allocate rowwise data buffer (1D flattened, uint8, but FP4 packs 2 values per byte) + data = torch.empty((total_elements) // 2, dtype=torch.uint8, device=device) + # Scale inverse buffer for NVFP4 - complex shape based on block scaling + # For simplicity, calculate total scale elements needed + total_scale_elements = 0 + scale_inv_offsets = [0] + for i, s in enumerate(shape): + scale_inv_shape = quantizer.get_scale_shape(s, False) + total_scale_elements += math.prod(scale_inv_shape) + if i < num_tensors - 1: + scale_inv_offsets.append(total_scale_elements) + scale_inv = torch.empty(total_scale_elements, dtype=torch.uint8, device=device) + # Amax buffer - one per tensor + amax = torch.empty(num_tensors, dtype=torch.float32, device=device) + + if columnwise_usage: + # Allocate columnwise data buffer (1D flattened, uint8, FP4 packed) + columnwise_data = torch.empty( + (total_elements) // 2, dtype=torch.uint8, device=device + ) + # Columnwise scale inverse buffer + total_columnwise_scale_elements = 0 + columnwise_scale_inv_offsets = [0] + for i, s in enumerate(shape): + columnwise_scale_inv_shape = quantizer.get_scale_shape(s, True) + total_columnwise_scale_elements += math.prod(columnwise_scale_inv_shape) + if i < num_tensors - 1: + columnwise_scale_inv_offsets.append(total_columnwise_scale_elements) + columnwise_scale_inv = torch.empty( + total_columnwise_scale_elements, dtype=torch.uint8, device=device + ) + # Columnwise amax buffer - one per tensor + columnwise_amax = torch.empty(num_tensors, dtype=torch.float32, device=device) + elif quantizer._get_compatible_recipe().float8_block_scaling(): + if rowwise_usage: + # Allocate rowwise data buffer (1D flattened, uint8) + data = torch.empty(total_elements, dtype=torch.uint8, device=device) + # Scale inverse - size depends on block configuration + # For simplicity, calculate total scale elements needed + total_scale_elements = 0 + scale_inv_offsets = [0] + for i, s in enumerate(shape): + scale_inv_shape = quantizer.get_scale_shape(s, False) + total_scale_elements += math.prod(scale_inv_shape) + if i < num_tensors - 1: + scale_inv_offsets.append(total_scale_elements) + scale_inv = torch.empty(total_scale_elements, dtype=torch.float32, device=device) + + if columnwise_usage: + # Allocate columnwise data buffer (1D flattened, uint8) + columnwise_data = torch.empty(total_elements, dtype=torch.uint8, device=device) + # Columnwise scale inverse + total_columnwise_scale_elements = 0 + columnwise_scale_inv_offsets = [0] + for i, s in enumerate(shape): + columnwise_scale_inv_shape = quantizer.get_scale_shape(s, True) + total_columnwise_scale_elements += math.prod(columnwise_scale_inv_shape) + if i < num_tensors - 1: + columnwise_scale_inv_offsets.append(total_columnwise_scale_elements) + columnwise_scale_inv = torch.empty( + total_columnwise_scale_elements, dtype=torch.float32, device=device + ) + elif quantizer._get_compatible_recipe().float8_current_scaling(): + # Current scaling - per-tensor scaling computed on the fly + if rowwise_usage: + # Allocate 
+        elif quantizer._get_compatible_recipe().float8_current_scaling():
+            # Current scaling - per-tensor scaling computed on the fly
+            if rowwise_usage:
+                # Allocate rowwise data buffer (1D flattened, uint8)
+                data = torch.empty(total_elements, dtype=torch.uint8, device=device)
+                # Scale inverse - one per tensor
+                scale_inv = torch.empty(num_tensors, dtype=torch.float32, device=device)
+                # One scale per tensor, so offsets are simply 0, 1, 2, ..., num_tensors-1
+                scale_inv_offsets = list(range(num_tensors))
+
+            if columnwise_usage:
+                # Allocate columnwise data buffer (1D flattened, uint8)
+                columnwise_data = torch.empty(total_elements, dtype=torch.uint8, device=device)
+                # Columnwise scale inverse - one per tensor
+                columnwise_scale_inv = torch.empty(num_tensors, dtype=torch.float32, device=device)
+                # One scale per tensor, so offsets are simply 0, 1, 2, ..., num_tensors-1
+                columnwise_scale_inv_offsets = list(range(num_tensors))
+
+            # Scale and amax buffers for current scaling - one per tensor
+            scale = torch.empty(num_tensors, dtype=torch.float32, device=device)
+            amax = torch.empty(num_tensors, dtype=torch.float32, device=device)
+        else:
+            raise ValueError(f"Unsupported quantizer for GroupedTensor: {quantizer}")
+
+        grouped_tensor = GroupedTensor(
+            num_tensors=num_tensors,
+            shape=shape,
+            dtype=dtype,
+            quantizer=quantizer,
+            data=data,
+            columnwise_data=columnwise_data,
+            scale_inv=scale_inv,
+            columnwise_scale_inv=columnwise_scale_inv,
+            amax=amax,
+            columnwise_amax=columnwise_amax,
+            scale=scale,
+            first_dims=first_dims,
+            last_dims=last_dims,
+            tensor_offsets=tensor_offsets,
+            offsets=offsets,
+            scale_inv_offsets=scale_inv_offsets,
+            columnwise_scale_inv_offsets=columnwise_scale_inv_offsets,
+            logical_shape=logical_shape,
+        )
+
+        grouped_tensor.quantized_tensors = grouped_tensor.split_into_quantized_tensors()
+        return grouped_tensor
+
+    def split_into_quantized_tensors(
+        self,
+    ) -> List[Union[QuantizedTensorStorage, torch.Tensor]]:
+        """
+        Split the GroupedTensor into a list of `num_tensors` quantized tensors
+        based on the quantizer. No additional memory is allocated: the returned
+        tensors are views into the buffers owned by the GroupedTensor.
+
+        If quantizer is None, returns normal torch tensors.
+        If quantizer.internal is True, returns QuantizedTensorStorage.
+        Otherwise, returns QuantizedTensor.
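+
+        Example (illustrative sketch, assuming ``grouped`` is an
+        already-constructed GroupedTensor)::
+
+            views = grouped.split_into_quantized_tensors()
+            assert len(views) == grouped.num_tensors
+            # Each entry aliases the grouped buffers; no copy is made.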
+ """ + + result = [] + + no_quantization = self.quantizer is None + + # Case 1: No quantization - return regular torch tensors + if no_quantization: + for i in range(self.num_tensors): + # Get tensor shape + tensor_shape = self.shape[i] + + # Get tensor data slice + if self.offsets is not None: + start_offset = self.offsets[i] + numel = tensor_shape[0] * tensor_shape[1] + end_offset = start_offset + numel + + if self.has_data(): + tensor_data = self.data[start_offset:end_offset].view(tensor_shape) + result.append(tensor_data) + elif self.has_columnwise_data(): + tensor_data = self.columnwise_data[start_offset:end_offset].view( + tensor_shape + ) + result.append(tensor_data) + else: + raise RuntimeError("GroupedTensor has no data to split") + else: + # All same shape case + numel = tensor_shape[0] * tensor_shape[1] + start_offset = i * numel + end_offset = start_offset + numel + + if self.has_data(): + tensor_data = self.data[start_offset:end_offset].view(tensor_shape) + result.append(tensor_data) + elif self.has_columnwise_data(): + tensor_data = self.columnwise_data[start_offset:end_offset].view( + tensor_shape + ) + result.append(tensor_data) + else: + raise RuntimeError("GroupedTensor has no data to split") + + return result + + # Case 2: Quantized tensors + recipe = self.quantizer._get_compatible_recipe() + + for i in range(self.num_tensors): + # Get tensor shape + tensor_shape = self.shape[i] + numel = tensor_shape[0] * tensor_shape[1] + + # Get data offsets + if self.offsets is not None: + data_start = self.offsets[i] + data_end = data_start + numel + else: + # All same shape + data_start = i * numel + data_end = data_start + numel + + # Special shape handling for NVFP4. + nvfp4 = self.quantizer._get_compatible_recipe().nvfp4() + if nvfp4: + data_start = data_start // 2 + data_end = data_end // 2 + + # Extract rowwise and columnwise data + rowwise_data = None + columnwise_data = None + + if self.has_data(): + if nvfp4: + rowwise_tensor_shape = self.quantizer.convert_shape_for_fp4(tensor_shape) + else: + rowwise_tensor_shape = tensor_shape + rowwise_data = self.data[data_start:data_end].view(rowwise_tensor_shape) + + if self.has_columnwise_data(): + columnwise_tensor_shape = self.quantizer.get_columnwise_shape(tensor_shape) + if nvfp4: + columnwise_tensor_shape = self.quantizer.convert_shape_for_fp4( + columnwise_tensor_shape + ) + columnwise_data = self.columnwise_data[data_start:data_end].view( + columnwise_tensor_shape + ) + + # MXFP8 format + if recipe.mxfp8(): + # Extract scale_inv data + rowwise_scale_inv = None + columnwise_scale_inv = None + + if self.scale_inv is not None and self.scale_inv_offsets is not None: + scale_start = self.scale_inv_offsets[i] + if i < self.num_tensors - 1: + scale_end = self.scale_inv_offsets[i + 1] + else: + scale_end = self.scale_inv.numel() + + # Calculate expected scale shape for MXFP8 + scale_shape = self.quantizer.get_scale_shape(tensor_shape, False) + rowwise_scale_inv = self.scale_inv[scale_start:scale_end].view(scale_shape) + + if ( + self.columnwise_scale_inv is not None + and self.columnwise_scale_inv_offsets is not None + ): + cscale_start = self.columnwise_scale_inv_offsets[i] + if i < self.num_tensors - 1: + cscale_end = self.columnwise_scale_inv_offsets[i + 1] + else: + cscale_end = self.columnwise_scale_inv.numel() + + cscale_shape = self.quantizer.get_scale_shape(tensor_shape, True) + columnwise_scale_inv = self.columnwise_scale_inv[cscale_start:cscale_end].view( + cscale_shape + ) + + if self.quantizer.internal: + 
+                    mxfp8_tensor_class = MXFP8TensorStorage
+                else:
+                    mxfp8_tensor_class = MXFP8Tensor
+                tensor = mxfp8_tensor_class(
+                    shape=tensor_shape,
+                    dtype=self.dtype,
+                    rowwise_data=rowwise_data,
+                    rowwise_scale_inv=rowwise_scale_inv,
+                    columnwise_data=columnwise_data,
+                    columnwise_scale_inv=columnwise_scale_inv,
+                    fp8_dtype=self.quantizer.dtype,
+                    quantizer=self.quantizer,
+                    with_gemm_swizzled_scales=self.quantizer.optimize_for_gemm,
+                )
+                result.append(tensor)
+
+            # Delayed scaling or current scaling (both use Float8TensorStorage)
+            elif recipe.delayed() or recipe.float8_current_scaling():
+                # Scale inverse - one per tensor
+                scale_inv = None
+                if self.scale_inv is not None:
+                    scale_inv = self.scale_inv[i : i + 1]
+
+                if self.quantizer.internal:
+                    float8_tensor_class = Float8TensorStorage
+                else:
+                    float8_tensor_class = Float8Tensor
+
+                tensor = float8_tensor_class(
+                    shape=tensor_shape,
+                    dtype=self.dtype,
+                    data=rowwise_data,
+                    fp8_scale_inv=scale_inv,
+                    fp8_dtype=self.quantizer.dtype,
+                    quantizer=self.quantizer,
+                    data_transpose=columnwise_data,
+                )
+                result.append(tensor)
+
+            # Float8 block scaling
+            elif recipe.float8_block_scaling():
+                # Extract scale_inv data
+                rowwise_scale_inv = None
+                columnwise_scale_inv = None
+
+                if self.scale_inv is not None and self.scale_inv_offsets is not None:
+                    scale_start = self.scale_inv_offsets[i]
+                    if i < self.num_tensors - 1:
+                        scale_end = self.scale_inv_offsets[i + 1]
+                    else:
+                        scale_end = self.scale_inv.numel()
+
+                    # Get scale shape from quantizer
+                    scale_shape = self.quantizer.get_scale_shape(tensor_shape, False)
+                    rowwise_scale_inv = self.scale_inv[scale_start:scale_end].view(scale_shape)
+
+                if (
+                    self.columnwise_scale_inv is not None
+                    and self.columnwise_scale_inv_offsets is not None
+                ):
+                    cscale_start = self.columnwise_scale_inv_offsets[i]
+                    if i < self.num_tensors - 1:
+                        cscale_end = self.columnwise_scale_inv_offsets[i + 1]
+                    else:
+                        cscale_end = self.columnwise_scale_inv.numel()
+
+                    # Get columnwise scale shape from quantizer
+                    cscale_shape = self.quantizer.get_scale_shape(tensor_shape, True)
+                    columnwise_scale_inv = self.columnwise_scale_inv[cscale_start:cscale_end].view(
+                        cscale_shape
+                    )
+
+                # Compute is_2D_scaled from quantizer attributes
+                is_2D_scaled = self.quantizer.block_scaling_dim == 2
+
+                if self.quantizer.internal:
+                    float8_blockwise_q_tensor_class = Float8BlockwiseQTensorStorage
+                else:
+                    float8_blockwise_q_tensor_class = Float8BlockwiseQTensor
+
+                tensor = float8_blockwise_q_tensor_class(
+                    shape=tensor_shape,
+                    dtype=self.dtype,
+                    rowwise_data=rowwise_data,
+                    rowwise_scale_inv=rowwise_scale_inv,
+                    columnwise_data=columnwise_data,
+                    columnwise_scale_inv=columnwise_scale_inv,
+                    fp8_dtype=self.quantizer.dtype,
+                    quantizer=self.quantizer,
+                    is_2D_scaled=is_2D_scaled,
+                )
+                result.append(tensor)
+
+            # NVFP4 format
+            elif recipe.nvfp4():
+                # Extract scale_inv data
+                rowwise_scale_inv = None
+                columnwise_scale_inv = None
+                amax_rowwise = None
+                amax_columnwise = None
+
+                if self.scale_inv is not None and self.scale_inv_offsets is not None:
+                    scale_start = self.scale_inv_offsets[i]
+                    if i < self.num_tensors - 1:
+                        scale_end = self.scale_inv_offsets[i + 1]
+                    else:
+                        scale_end = self.scale_inv.numel()
+
+                    # Get scale shape from quantizer
+                    scale_shape = self.quantizer.get_scale_shape(tensor_shape, False)
+                    rowwise_scale_inv = self.scale_inv[scale_start:scale_end].view(scale_shape)
+
+                if (
+                    self.columnwise_scale_inv is not None
+                    and self.columnwise_scale_inv_offsets is not None
+                ):
+                    cscale_start = self.columnwise_scale_inv_offsets[i]
+                    if i < self.num_tensors - 1:
+                        cscale_end = self.columnwise_scale_inv_offsets[i + 1]
+                    else:
+                        cscale_end = self.columnwise_scale_inv.numel()
+
+                    # Get columnwise scale shape from quantizer
+                    cscale_shape = self.quantizer.get_scale_shape(tensor_shape, True)
+                    columnwise_scale_inv = self.columnwise_scale_inv[cscale_start:cscale_end].view(
+                        cscale_shape
+                    )
+
+                # Extract amax - one per tensor
+                if self.amax is not None:
+                    amax_rowwise = self.amax[i : i + 1]
+
+                if self.columnwise_amax is not None:
+                    amax_columnwise = self.columnwise_amax[i : i + 1]
+
+                if self.quantizer.internal:
+                    nvfp4_tensor_class = NVFP4TensorStorage
+                else:
+                    nvfp4_tensor_class = NVFP4Tensor
+
+                tensor = nvfp4_tensor_class(
+                    shape=tensor_shape,
+                    dtype=self.dtype,
+                    rowwise_data=rowwise_data,
+                    rowwise_scale_inv=rowwise_scale_inv,
+                    columnwise_data=columnwise_data,
+                    columnwise_scale_inv=columnwise_scale_inv,
+                    amax_rowwise=amax_rowwise,
+                    amax_columnwise=amax_columnwise,
+                    fp4_dtype=self.quantizer.dtype,
+                    quantizer=self.quantizer,
+                    with_gemm_swizzled_scales=self.quantizer.optimize_for_gemm,
+                )
+                result.append(tensor)
+
+            else:
+                raise ValueError(f"Unsupported quantization recipe: {recipe}")
+
+        return result
+
+    @staticmethod
+    def create_and_quantize(
+        tensors: List[torch.Tensor],
+        quantizer: Optional[Quantizer],
+        *,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+        noop_flag: Optional[torch.Tensor] = None,
+    ) -> "GroupedTensor":
+        """
+        Quantize the given tensors into quantized tensors whose underlying
+        storage is allocated in a single GroupedTensor, and return that
+        GroupedTensor.
+        """
+
+        grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
+            num_tensors=len(tensors),
+            shape=[t.shape for t in tensors],
+            quantizer=quantizer,
+            device=device,
+            dtype=dtype,
+        )
+
+        grouped_tensor.quantize(tensors, noop_flag=noop_flag)
+
+        return grouped_tensor
+
+    def quantize(
+        self,
+        tensors: List[torch.Tensor],
+        noop_flag: Optional[torch.Tensor] = None,
+    ) -> List[Union[QuantizedTensorStorage, torch.Tensor]]:
+        """
+        Quantize the given tensors into this GroupedTensor's storage in place.
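+
+        Example (illustrative sketch; ``quantizer`` is any supported Quantizer
+        and the shapes are arbitrary)::
+
+            inputs = [torch.randn(128, 256, device="cuda") for _ in range(4)]
+            grouped = GroupedTensor.create_and_quantize(inputs, quantizer)
+            # Later, re-quantize fresh values into the same grouped storage.
+            grouped.quantize(inputs)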
+ """ + + quantized_tensors = self.split_into_quantized_tensors() + for i in range(self.num_tensors): + self.quantizer.update_quantized(tensors[i], quantized_tensors[i], noop_flag=noop_flag) + return quantized_tensors diff --git a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py index 1951731c75..5c8510488f 100644 --- a/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/mxfp8_tensor_storage.py @@ -111,6 +111,24 @@ def clear(self): if t is not None: t.data = _empty_tensor() + def copy_from_storage(self, src: QuantizedTensorStorage) -> None: + """Copy data buffers from another MXFP8TensorStorage.""" + if not isinstance(src, MXFP8TensorStorage): + raise TypeError("copy_from_storage expects MXFP8TensorStorage") + if self._fp8_dtype != src._fp8_dtype: + raise RuntimeError("FP8 dtype mismatch in copy_from_storage") + if self._with_gemm_swizzled_scales != src._with_gemm_swizzled_scales: + raise RuntimeError("Scale layout mismatch in copy_from_storage") + + def _copy_optional(dst: Optional[torch.Tensor], src_tensor: Optional[torch.Tensor]): + if dst is not None and src_tensor is not None: + dst.copy_(src_tensor) + + _copy_optional(self._rowwise_data, src._rowwise_data) + _copy_optional(self._columnwise_data, src._columnwise_data) + _copy_optional(self._rowwise_scale_inv, src._rowwise_scale_inv) + _copy_optional(self._columnwise_scale_inv, src._columnwise_scale_inv) + def get_metadata(self) -> Dict[str, Any]: """Get this tensor's metadata.""" return { diff --git a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py index b064d711ce..8be23d0c19 100644 --- a/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py +++ b/transformer_engine/pytorch/tensor/storage/nvfp4_tensor_storage.py @@ -136,6 +136,26 @@ def clear(self): if t is not None: t.data = _empty_tensor() + def copy_from_storage(self, src: QuantizedTensorStorage) -> None: + """Copy data buffers from another NVFP4TensorStorage.""" + if not isinstance(src, NVFP4TensorStorage): + raise TypeError("copy_from_storage expects NVFP4TensorStorage") + if self._fp4_dtype != src._fp4_dtype: + raise RuntimeError("FP4 dtype mismatch in copy_from_storage") + if self._with_gemm_swizzled_scales != src._with_gemm_swizzled_scales: + raise RuntimeError("Scale layout mismatch in copy_from_storage") + + def _copy_optional(dst: Optional[torch.Tensor], src_tensor: Optional[torch.Tensor]): + if dst is not None and src_tensor is not None: + dst.copy_(src_tensor) + + _copy_optional(self._rowwise_data, src._rowwise_data) + _copy_optional(self._columnwise_data, src._columnwise_data) + _copy_optional(self._rowwise_scale_inv, src._rowwise_scale_inv) + _copy_optional(self._columnwise_scale_inv, src._columnwise_scale_inv) + _copy_optional(self._amax_rowwise, src._amax_rowwise) + _copy_optional(self._amax_columnwise, src._amax_columnwise) + def get_metadata(self) -> Dict[str, Any]: """Get this tensor's metadata.""" return {