Merged
46 commits
5e39835
Python GroupedTensor and contiguous weights for GroupedLinear
ksivaman Jan 15, 2026
66e7d7f
Merge branch 'main' into grouped_tensor_python
ksivaman Jan 15, 2026
40c619e
Graph safe C API for grouped RHT, needs testing
ksivaman Jan 16, 2026
cf61339
Merge branch 'main' into grouped_tensor_python
ksivaman Jan 16, 2026
759e7bb
C++ utils, untested
ksivaman Jan 16, 2026
1d09c2a
Merge branch 'main' into grouped_tensor_python
vthumbe1503 Jan 23, 2026
e1b65ac
Merge branch 'main' into grouped_tensor_python
ksivaman Feb 2, 2026
3ba639e
Pytorch Binding for GroupedTensor APIs (#13)
vthumbe1503 Feb 4, 2026
ebf2194
Merge branch 'main' into grouped_tensor_python
ksivaman Feb 4, 2026
4337520
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 4, 2026
5ab30f5
Fix make grouped tensor api
ksivaman Feb 5, 2026
05dab12
Merge branch 'main' into grouped_tensor_python
ksivaman Feb 6, 2026
68ce836
Fixes to tests
ksivaman Feb 6, 2026
3e7859c
PyTorch-Python GroupedTensor
ksivaman Feb 6, 2026
53c38ec
Merge branch 'main' into grouped_tensor_python
ksivaman Feb 8, 2026
d57651d
Fix test
ksivaman Feb 9, 2026
bd41fd0
All tests pass
ksivaman Feb 9, 2026
351b74d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 9, 2026
fd8ce0f
Update transformer_engine/pytorch/tensor/storage/grouped_tensor.py
ksivaman Feb 9, 2026
24cfd8c
Remove mxfp8 gq test
ksivaman Feb 9, 2026
97a1f33
C++ PyTorch GroupedTensor changes WIP
ksivaman Feb 9, 2026
82f7ebe
Merge branch 'main' into pytorch_python_grouped_tensor
ksivaman Feb 10, 2026
e1788b3
Compiles
ksivaman Feb 10, 2026
9022383
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 10, 2026
11095a9
Fix runtime failure for test
ksivaman Feb 10, 2026
373d9e3
Fix IMA in mxfp8 GQ
ksivaman Feb 10, 2026
1601960
Add CG test for grouped_quantize
ksivaman Feb 10, 2026
bd57000
Fix recipe tests and FP8 weights
ksivaman Feb 10, 2026
91ab416
Fix recipe tests and FP8 weights
ksivaman Feb 10, 2026
52ab0ed
Merge branch 'main' into pytorch_python_grouped_tensor
ksivaman Feb 10, 2026
a5de7a5
Fix device test
ksivaman Feb 11, 2026
77fa728
Disable grouped weights for unsupported recipes
ksivaman Feb 11, 2026
9009f75
Merge branch 'pytorch_python_grouped_tensor' into grouped_tensor_python
ksivaman Feb 11, 2026
bea794f
Merge branch 'main' into grouped_tensor_python
ksivaman Feb 11, 2026
6b0c420
Merge branch 'main' into grouped_tensor_python
ksivaman Feb 11, 2026
e3278dd
Merge branch 'main' into grouped_tensor_python
ksivaman Feb 12, 2026
864c484
Integrate NVFP4 Graph Safe Group Quantize (#14)
zhongbozhu Feb 14, 2026
4ee0339
improve mxfp8 unit test
zhongbozhu Feb 17, 2026
9f5f24c
pre-swizzle nvfp4 mxfp8 for MoE
zhongbozhu Feb 18, 2026
22f8a5b
avoid having nvte_get_grouped_tensor_param_v2
zhongbozhu Feb 18, 2026
63e1563
more tests
zhongbozhu Feb 18, 2026
4d66324
fix group quantize mxfp8 kernel
zhongbozhu Feb 20, 2026
621fb0e
Relaxed restriction for the last dim to be a multiple of 128
Oleg-Goncharov Feb 24, 2026
439c933
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 24, 2026
3de9850
Merge branch 'main' into grouped_tensor_python
ksivaman Feb 25, 2026
5bf0cf9
Merge branch 'main' into grouped_tensor_python
ksivaman Feb 26, 2026
429 changes: 429 additions & 0 deletions tests/pytorch/test_grouped_tensor.py

Large diffs are not rendered by default.
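The new test file's diff is collapsed above, so here is a rough, hypothetical sketch of the invariant it presumably exercises (not the actual file contents): a GroupedTensor packs its member tensors back-to-back in one allocation, so each tensor's data pointer must equal the previous tensor's pointer plus its size in bytes. The helper name and shapes below are illustrative; only the pointer arithmetic mirrors the check added to test_sanity.py further down.

# Hypothetical sketch, not the contents of test_grouped_tensor.py.
import torch

def assert_contiguous_storage(tensors, num_elems_in_byte=1):
    """Assert each tensor starts exactly where the previous one ends in memory."""
    for prev, curr in zip(tensors[:-1], tensors[1:]):
        expected = prev.data_ptr() + (prev.numel() // num_elems_in_byte) * prev.element_size()
        assert curr.data_ptr() == expected

# Plain (non-quantized) analogue: three views carved out of one bf16 buffer.
storage = torch.empty(3 * 4 * 8, dtype=torch.bfloat16)
views = [storage[i * 32 : (i + 1) * 32].view(4, 8) for i in range(3)]
assert_contiguous_storage(views)  # passes: the views share one contiguous buffer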

131 changes: 128 additions & 3 deletions tests/pytorch/test_sanity.py
@@ -2,7 +2,7 @@
#
# See LICENSE for license information.

from typing import Optional
from typing import Optional, List

import torch
import pytest
@@ -137,6 +137,117 @@ def reset_global_fp8_state():
FP8GlobalStateManager.reset()


def check_grouped_tensor_pointers_helper(tensors, num_elems_in_byte=1, tensor_name="tensor"):
"""
Verify that tensors are stored in contiguous memory.

Args:
tensors: List or iterable of tensors to check
num_elems_in_byte: Number of elements packed per byte (1 for normal, 2 for NVFP4)
tensor_name: Name to use in error messages
"""
tensor_list = list(tensors)
if len(tensor_list) < 2:
return # Nothing to check

for i in range(1, len(tensor_list)):
prev_tensor = tensor_list[i - 1]
curr_tensor = tensor_list[i]

# Calculate expected offset based on previous tensor size
prev_numel = prev_tensor.numel()
expected_offset = (prev_numel // num_elems_in_byte) * prev_tensor.element_size()

# Verify current tensor's data pointer is correctly offset
expected_ptr = prev_tensor.data_ptr() + expected_offset
actual_ptr = curr_tensor.data_ptr()

assert (
actual_ptr == expected_ptr
), f"{tensor_name} {i} data pointer mismatch: expected {expected_ptr}, got {actual_ptr}"


def check_grouped_tensor_pointers(
weights: List[torch.Tensor], fp8_recipe: Optional[recipe.Recipe] = None
):
"""
Verify that the pointers of the weights are in contiguous memory for GroupedTensor.
TODO(ksivaman): This check can be made way more efficient but for now leaving the brute force approach.
"""

num_elems_in_a_data_byte = 1 if fp8_recipe is None else 2 if fp8_recipe.nvfp4() else 1

# Check data.
if hasattr(weights[0], "_data") and weights[0]._data is not None:
data_tensors = [w._data for w in weights]
check_grouped_tensor_pointers_helper(data_tensors, num_elems_in_byte=1, tensor_name="data")

# Check transpose.
if hasattr(weights[0], "_transpose") and weights[0]._transpose is not None:
transpose_tensors = [w._transpose for w in weights]
check_grouped_tensor_pointers_helper(
transpose_tensors, num_elems_in_byte=1, tensor_name="transpose"
)

# Check scale_inv.
if hasattr(weights[0], "_scale_inv") and weights[0]._scale_inv is not None:
scale_inv_tensors = [w._scale_inv for w in weights]
check_grouped_tensor_pointers_helper(
scale_inv_tensors, num_elems_in_byte=1, tensor_name="scale_inv"
)

# Check rowwise scale_inv.
if hasattr(weights[0], "_rowwise_scale_inv") and weights[0]._rowwise_scale_inv is not None:
scale_inv_tensors = [w._rowwise_scale_inv for w in weights]
check_grouped_tensor_pointers_helper(
scale_inv_tensors, num_elems_in_byte=1, tensor_name="rowwise_scale_inv"
)

# Check columnwise scale_inv.
if (
hasattr(weights[0], "_columnwise_scale_inv")
and weights[0]._columnwise_scale_inv is not None
):
columnwise_scale_inv_tensors = [w._columnwise_scale_inv for w in weights]
check_grouped_tensor_pointers_helper(
columnwise_scale_inv_tensors,
num_elems_in_byte=1,
tensor_name="columnwise scale_inv",
)

# Check rowwise amax.
if hasattr(weights[0], "_rowwise_amax") and weights[0]._rowwise_amax is not None:
rowwise_amax_tensors = [w._rowwise_amax for w in weights]
check_grouped_tensor_pointers_helper(
rowwise_amax_tensors, num_elems_in_byte=1, tensor_name="rowwise amax"
)

# Check columnwise amax.
if hasattr(weights[0], "_columnwise_amax") and weights[0]._columnwise_amax is not None:
columnwise_amax_tensors = [w._columnwise_amax for w in weights]
check_grouped_tensor_pointers_helper(
columnwise_amax_tensors, num_elems_in_byte=1, tensor_name="columnwise amax"
)

# Check rowwise data.
if hasattr(weights[0], "_rowwise_data") and weights[0]._rowwise_data is not None:
rowwise_data_tensors = [w._rowwise_data for w in weights]
check_grouped_tensor_pointers_helper(
rowwise_data_tensors,
num_elems_in_byte=num_elems_in_a_data_byte,
tensor_name="rowwise data",
)

# Check columnwise data.
if hasattr(weights[0], "_columnwise_data") and weights[0]._columnwise_data is not None:
columnwise_data_tensors = [w._columnwise_data for w in weights]
check_grouped_tensor_pointers_helper(
columnwise_data_tensors,
num_elems_in_byte=num_elems_in_a_data_byte,
tensor_name="columnwise data",
)


def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn(
(config.max_seqlen_q, config.batch_size, config.hidden_size),
@@ -495,9 +606,17 @@ def test_sanity_grouped_linear(
use_fp8 = fp8_recipe is not None
with quantized_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
te_grouped_linear = GroupedLinear(
num_gemms, config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype
num_gemms,
config.hidden_size,
ffn_hidden_size,
bias=use_bias,
params_dtype=dtype,
).cuda()

# Verify that weights are stored in contiguous GroupedTensor storage.
weights = [getattr(te_grouped_linear, f"weight{i}") for i in range(num_gemms)]
check_grouped_tensor_pointers(weights, fp8_recipe)

inp_hidden_states = torch.randn(
num_tokens, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
@@ -956,7 +1075,13 @@ def test_replace_raw_data_for_float8tensor():
random_bf16_data = torch.randn(fp8_tensor.shape, dtype=torch.bfloat16, device="cuda")
fp8_quantizer.update_quantized(random_bf16_data, fp8_tensor)

attrs_to_check = ["_quantizer", "_fp8_dtype", "_scale_inv", "_transpose", "_transpose_invalid"]
attrs_to_check = [
"_quantizer",
"_fp8_dtype",
"_scale_inv",
"_transpose",
"_transpose_invalid",
]
attrs = {}
for attr in attrs_to_check:
attrs[attr] = getattr(fp8_tensor, attr)
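One detail worth spelling out from check_grouped_tensor_pointers_helper in this file: num_elems_in_byte exists because NVFP4 packs two 4-bit values into each byte, so the expected byte offset between consecutive data buffers is half the element count. A small worked example with made-up numbers (not part of the diff):

# Offset arithmetic used by check_grouped_tensor_pointers_helper (illustrative numbers).
numel = 4096        # elements in the previous tensor
element_size = 1    # bytes per storage element for a uint8-backed FP8 buffer

# FP8 storage: one element per byte, so the next tensor starts 4096 bytes later.
assert (numel // 1) * element_size == 4096

# NVFP4 storage: two 4-bit elements per byte, so the next tensor starts 2048 bytes later.
assert (numel // 2) * element_size == 2048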
35 changes: 21 additions & 14 deletions transformer_engine/common/recipe/__init__.py
@@ -88,33 +88,40 @@ class Recipe:
Base recipe class.
"""

def nvfp4(self):
@classmethod
def nvfp4(cls):
"""Whether the given recipe is NVFP4 1D block scaling."""
return isinstance(self, NVFP4BlockScaling)
return issubclass(cls, NVFP4BlockScaling)

def mxfp8(self):
@classmethod
def mxfp8(cls):
"""Whether the given recipe is MXFP8 block scaling."""
return isinstance(self, MXFP8BlockScaling)
return issubclass(cls, MXFP8BlockScaling)

def delayed(self):
@classmethod
def delayed(cls):
"""Whether the given recipe is delayed scaling."""
return isinstance(self, DelayedScaling)
return issubclass(cls, DelayedScaling)

def float8_current_scaling(self):
@classmethod
def float8_current_scaling(cls):
"""Whether the given recipe is (per-tensor) current scaling."""
return isinstance(self, Float8CurrentScaling)
return issubclass(cls, Float8CurrentScaling)

def float8_per_tensor_scaling(self):
@classmethod
def float8_per_tensor_scaling(cls):
"""Whether the given recipe is per-tensor scaling."""
return isinstance(self, (DelayedScaling, Float8CurrentScaling))
return issubclass(cls, (DelayedScaling, Float8CurrentScaling))

def float8_block_scaling(self):
@classmethod
def float8_block_scaling(cls):
"""Whether the given recipe is float8 blockwise scaling."""
return isinstance(self, Float8BlockScaling)
return issubclass(cls, Float8BlockScaling)

def custom(self):
@classmethod
def custom(cls):
"""Whether the given recipe is custom."""
return isinstance(self, CustomRecipe)
return issubclass(cls, CustomRecipe)


@dataclass()
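The predicates above change from instance checks (`isinstance(self, ...)`) to classmethods using `issubclass(cls, ...)`, so a recipe's type can now be queried from the class itself as well as from an instance; existing call sites such as `fp8_recipe.nvfp4()` in the new test helper keep working. A small behavioural sketch, assuming the recipe classes construct with their defaults (this snippet is not part of the diff):

# Sketch: classmethod predicates answer on both the class and an instance.
from transformer_engine.common import recipe

assert recipe.MXFP8BlockScaling.mxfp8()      # now answerable on the class itself
assert recipe.MXFP8BlockScaling().mxfp8()    # and, as before, on an instance
assert not recipe.MXFP8BlockScaling.nvfp4()  # other predicates stay False
assert recipe.DelayedScaling().float8_per_tensor_scaling()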
78 changes: 69 additions & 9 deletions transformer_engine/pytorch/module/grouped_linear.py
@@ -13,6 +13,7 @@
import transformer_engine_torch as tex

from transformer_engine.common.recipe import Recipe
from transformer_engine.pytorch.tensor.storage.grouped_tensor import GroupedTensor
from .base import (
get_dummy_wgrad,
TransformerEngineBaseModule,
@@ -147,7 +148,10 @@ def forward(
# tensors (like scales), but bulk allocation shares storage across all tensors,
# so if scales can't be offloaded, nothing in the group can be offloaded.
inputmats = tex.split_quantize(
inp_view, m_splits, input_quantizers, disable_bulk_allocation=cpu_offloading
inp_view,
m_splits,
input_quantizers,
disable_bulk_allocation=cpu_offloading,
)
elif debug:
inputmats = DebugQuantizer.multi_tensor_quantize(
@@ -365,7 +369,10 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
for i in range(ctx.num_gemms):
grad_biases[i] = grad_output_mats[i].sum(dim=0)
grad_output = DebugQuantizer.multi_tensor_quantize(
grad_output_view, ctx.grad_output_quantizers, ctx.m_splits, ctx.activation_dtype
grad_output_view,
ctx.grad_output_quantizers,
ctx.m_splits,
ctx.activation_dtype,
)
else:
# Only split grad output. Grad bias is fused with
@@ -436,7 +443,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
if ctx.input_quantizers[0] is not None:
for input_quantizer in ctx.input_quantizers:
if isinstance(
input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
input_quantizer,
(Float8Quantizer, Float8CurrentScalingQuantizer),
):
input_quantizer.set_usage(rowwise=True, columnwise=True)
else:
@@ -446,7 +454,10 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
inputmats = tex.split_quantize(inp_view, ctx.m_splits, ctx.input_quantizers)
elif ctx.debug:
inputmats = DebugQuantizer.multi_tensor_quantize(
inp_view, ctx.input_quantizers, ctx.m_splits, ctx.activation_dtype
inp_view,
ctx.input_quantizers,
ctx.m_splits,
ctx.activation_dtype,
)
else:
inputmats = torch.split(
@@ -616,7 +627,7 @@ def __init__(
) -> None:
super().__init__()

params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
self.params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
self.num_gemms = num_gemms
self.in_features = in_features
self.out_features = out_features
@@ -631,13 +642,20 @@ def __init__(
assert (
not ub_overlap_rs and not ub_overlap_ag
), "GroupedLinear doesn't support Userbuffer overlap."
self.init_method = init_method
self.get_rng_state_tracker = get_rng_state_tracker
self.rng_tracker_name = rng_tracker_name
self.name = name

self.wgrad_store = WeightGradStore(delay_wgrad_compute)

self._offsets = {"input": 0, "weight": 1, "output": 2, "grad_output": 0, "grad_input": 1}
self._offsets = {
"input": 0,
"weight": 1,
"output": 2,
"grad_output": 0,
"grad_input": 1,
}
self._num_fp8_tensors_per_gemm = {
"fwd": 3,
"bwd": 2,
@@ -679,7 +697,7 @@ def __init__(
self.out_features,
self.in_features,
device=device,
dtype=params_dtype,
dtype=self.params_dtype,
),
),
init_fn=init_method,
@@ -695,20 +713,21 @@ def __init__(
torch.empty(
self.out_features,
device=device,
dtype=params_dtype,
dtype=self.params_dtype,
),
),
init_fn=init_method_constant(0.0),
)
else:
bias = torch.Tensor().to(dtype=params_dtype, device=device)
bias = torch.Tensor().to(dtype=self.params_dtype, device=device)
setattr(self, f"bias{i}", bias)

if self.primary_weights_in_fp8:
self.init_fp8_metadata(num_gemms=self.num_gemms)

is_meta = torch.device(device).type == "meta"
self.reset_parameters(defer_init=is_meta)
self.make_grouped_weights(defer_init=is_meta)

if self.wgrad_store.delay_wgrad_compute():
for name, param in self.named_parameters():
@@ -729,8 +748,49 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
)
self._customize_quantizers_float8_current_scaling(fwd, recipe)

def make_grouped_weights(self, defer_init=False) -> None:
"""
Convert parameters into a GroupedTensor and re-register them as parameters.
"""

if defer_init:
return

weights = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
weight_quantizers = self._get_weight_quantizers()

# Create the weight storage.
grouped_weights = GroupedTensor.make_grouped_tensor(
num_tensors=self.num_gemms,
shape=[(self.out_features, self.in_features)] * self.num_gemms,
quantizers=weight_quantizers,
dtype=self.params_dtype,
)

# Copy existing params into storage.
# TODO(ksivamani): Verify correctness of copy for all recipes.
with torch.no_grad():
for i in range(self.num_gemms):
grouped_weights.quantized_tensors[i].copy_(weights[i])
Reviewer comment: check that the copy operation works correctly for all quantization recipes (FP8, MXFP8, NVFP4, block scaling); the TODO comment on line 771 acknowledges this needs verification.


# Re-register the grouped weights as parameters.
for i in range(self.num_gemms):
self.register_parameter(
f"weight{i}",
torch.nn.Parameter(grouped_weights.quantized_tensors[i]),
init_fn=self.init_method,
get_rng_state_tracker=self.get_rng_state_tracker,
fp8_meta_index=self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"],
)

self.set_tensor_parallel_attributes(defer_init=defer_init)

def reset_parameters(self, defer_init=False):
super().reset_parameters(defer_init=defer_init)
self.set_tensor_parallel_attributes(defer_init=defer_init)

def set_tensor_parallel_attributes(self, defer_init=False) -> None:
"""Set attributes needed for TP"""

if not defer_init:
# Set parallelism attributes for linear weights
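Net effect of `make_grouped_weights`: after construction, a GroupedLinear's per-GEMM weights are views into one contiguous GroupedTensor allocation, which is what the new sanity check verifies. A hedged usage sketch follows; the shapes and dtype are arbitrary, and the expectation that the plain-bf16 path also packs weights contiguously is inferred from the PR description rather than asserted by the sanity test for non-FP8 params.

# Sketch: observe the contiguous weight storage produced by make_grouped_weights.
import torch
from transformer_engine.pytorch import GroupedLinear

num_gemms, in_features, out_features = 4, 128, 256
layer = GroupedLinear(
    num_gemms, in_features, out_features, params_dtype=torch.bfloat16
).cuda()

weights = [getattr(layer, f"weight{i}") for i in range(num_gemms)]
for prev, curr in zip(weights[:-1], weights[1:]):
    # Each weight should begin exactly where the previous one ends.
    assert curr.data_ptr() == prev.data_ptr() + prev.numel() * prev.element_size()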