Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
Show all changes
46 commits
Select commit Hold shift + click to select a range
5e39835
Python GroupedTensor and contiguous weights for GroupedLinear
ksivaman Jan 15, 2026
66e7d7f
Merge branch 'main' into grouped_tensor_python
ksivaman Jan 15, 2026
40c619e
Graph safe C API for grouped RHT, needs testing
ksivaman Jan 16, 2026
cf61339
Merge branch 'main' into grouped_tensor_python
ksivaman Jan 16, 2026
759e7bb
C++ utils, untested
ksivaman Jan 16, 2026
1d09c2a
Merge branch 'main' into grouped_tensor_python
vthumbe1503 Jan 23, 2026
e1b65ac
Merge branch 'main' into grouped_tensor_python
ksivaman Feb 2, 2026
3ba639e
Pytorch Binding for GroupedTensor APIs (#13)
vthumbe1503 Feb 4, 2026
ebf2194
Merge branch 'main' into grouped_tensor_python
ksivaman Feb 4, 2026
4337520
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 4, 2026
5ab30f5
Fix make grouped tensor api
ksivaman Feb 5, 2026
05dab12
Merge branch 'main' into grouped_tensor_python
ksivaman Feb 6, 2026
68ce836
Fixes to tests
ksivaman Feb 6, 2026
3e7859c
PyTorch-Python GroupedTensor
ksivaman Feb 6, 2026
53c38ec
Merge branch 'main' into grouped_tensor_python
ksivaman Feb 8, 2026
d57651d
Fix test
ksivaman Feb 9, 2026
bd41fd0
All tests pass
ksivaman Feb 9, 2026
351b74d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 9, 2026
fd8ce0f
Update transformer_engine/pytorch/tensor/storage/grouped_tensor.py
ksivaman Feb 9, 2026
24cfd8c
Remove mxfp8 gq test
ksivaman Feb 9, 2026
97a1f33
C++ PyTorch GroupedTensor changes WIP
ksivaman Feb 9, 2026
82f7ebe
Merge branch 'main' into pytorch_python_grouped_tensor
ksivaman Feb 10, 2026
e1788b3
Compiles
ksivaman Feb 10, 2026
9022383
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 10, 2026
11095a9
Fix runtime failure for test
ksivaman Feb 10, 2026
373d9e3
Fix IMA in mxfp8 GQ
ksivaman Feb 10, 2026
1601960
Add CG test for grouped_quantize
ksivaman Feb 10, 2026
bd57000
Fix recipe tests and FP8 weights
ksivaman Feb 10, 2026
91ab416
Fix recipe tests and FP8 weights
ksivaman Feb 10, 2026
52ab0ed
Merge branch 'main' into pytorch_python_grouped_tensor
ksivaman Feb 10, 2026
a5de7a5
Fix device test
ksivaman Feb 11, 2026
77fa728
Disable grouped weights for unsupported recipes
ksivaman Feb 11, 2026
9009f75
Merge branch 'pytorch_python_grouped_tensor' into grouped_tensor_python
ksivaman Feb 11, 2026
bea794f
Merge branch 'main' into grouped_tensor_python
ksivaman Feb 11, 2026
6b0c420
Merge branch 'main' into grouped_tensor_python
ksivaman Feb 11, 2026
e3278dd
Merge branch 'main' into grouped_tensor_python
ksivaman Feb 12, 2026
864c484
Integrate NVFP4 Graph Safe Group Quantize (#14)
zhongbozhu Feb 14, 2026
4ee0339
improve mxfp8 unit test
zhongbozhu Feb 17, 2026
9f5f24c
pre-swizzle nvfp4 mxfp8 for MoE
zhongbozhu Feb 18, 2026
22f8a5b
avoid having nvte_get_grouped_tensor_param_v2
zhongbozhu Feb 18, 2026
63e1563
more tests
zhongbozhu Feb 18, 2026
4d66324
fix group quantize mxfp8 kernel
zhongbozhu Feb 20, 2026
621fb0e
Relaxed restriction for the last dim to be a multiple of 128
Oleg-Goncharov Feb 24, 2026
439c933
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 24, 2026
3de9850
Merge branch 'main' into grouped_tensor_python
ksivaman Feb 25, 2026
5bf0cf9
Merge branch 'main' into grouped_tensor_python
ksivaman Feb 26, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
429 changes: 429 additions & 0 deletions tests/pytorch/test_grouped_tensor.py

Large diffs are not rendered by default.

131 changes: 128 additions & 3 deletions tests/pytorch/test_sanity.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#
# See LICENSE for license information.

from typing import Optional
from typing import Optional, List

import torch
import pytest
Expand Down Expand Up @@ -137,6 +137,117 @@ def reset_global_fp8_state():
FP8GlobalStateManager.reset()


def check_grouped_tensor_pointers_helper(tensors, num_elems_in_byte=1, tensor_name="tensor"):
    """
    Assert that consecutive tensors occupy adjacent regions of one memory buffer.

    Args:
        tensors: List or iterable of tensors to check
        num_elems_in_byte: Number of elements packed per byte (1 for normal, 2 for NVFP4)
        tensor_name: Name to use in error messages
    """
    seq = list(tensors)
    # Walk consecutive pairs; each tensor must begin exactly where the
    # previous one ends. Fewer than two tensors yields no pairs to check.
    for idx, (prev, curr) in enumerate(zip(seq, seq[1:]), start=1):
        # Byte size of the previous tensor, accounting for sub-byte
        # element packing (e.g. two NVFP4 elements share one byte).
        prev_bytes = (prev.numel() // num_elems_in_byte) * prev.element_size()
        want = prev.data_ptr() + prev_bytes
        got = curr.data_ptr()
        assert (
            got == want
        ), f"{tensor_name} {idx} data pointer mismatch: expected {want}, got {got}"


def check_grouped_tensor_pointers(
    weights: List[torch.Tensor], fp8_recipe: Optional["recipe.Recipe"] = None
):
    """
    Verify that the pointers of the weights are in contiguous memory for GroupedTensor.

    Every quantized-tensor attribute present on the weights (data, transpose,
    scale-inverses, amaxes) is gathered across all weights and checked for
    back-to-back placement in memory.

    Args:
        weights: Weight tensors produced by a module using GroupedTensor storage.
        fp8_recipe: Active quantization recipe, or ``None`` for high precision.
            NVFP4 recipes pack two elements per data byte, which changes the
            expected byte offsets of the data buffers.

    TODO(ksivaman): This check can be made way more efficient but for now leaving the brute force approach.
    """
    # NVFP4 packs two 4-bit elements into each data byte; every other
    # recipe (and no recipe) stores one element per byte.
    num_elems_in_a_data_byte = 2 if fp8_recipe is not None and fp8_recipe.nvfp4() else 1

    # (attribute name, label used in error messages, elements packed per byte).
    # Labels are kept byte-identical to the historical per-attribute checks.
    attr_specs = [
        ("_data", "data", 1),
        ("_transpose", "transpose", 1),
        ("_scale_inv", "scale_inv", 1),
        ("_rowwise_scale_inv", "rowwise_scale_inv", 1),
        ("_columnwise_scale_inv", "columnwise scale_inv", 1),
        ("_rowwise_amax", "rowwise amax", 1),
        ("_columnwise_amax", "columnwise amax", 1),
        ("_rowwise_data", "rowwise data", num_elems_in_a_data_byte),
        ("_columnwise_data", "columnwise data", num_elems_in_a_data_byte),
    ]
    for attr, label, elems_per_byte in attr_specs:
        # Which attributes exist depends on the tensor class / recipe; only
        # check the ones the first weight actually carries.
        if getattr(weights[0], attr, None) is not None:
            check_grouped_tensor_pointers_helper(
                [getattr(w, attr) for w in weights],
                num_elems_in_byte=elems_per_byte,
                tensor_name=label,
            )


def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad):
te_inp_hidden_states = torch.randn(
(config.max_seqlen_q, config.batch_size, config.hidden_size),
Expand Down Expand Up @@ -495,9 +606,17 @@ def test_sanity_grouped_linear(
use_fp8 = fp8_recipe is not None
with quantized_model_init(enabled=use_fp8 and fp8_model_params, recipe=fp8_recipe):
te_grouped_linear = GroupedLinear(
num_gemms, config.hidden_size, ffn_hidden_size, bias=use_bias, params_dtype=dtype
num_gemms,
config.hidden_size,
ffn_hidden_size,
bias=use_bias,
params_dtype=dtype,
).cuda()

# Verify that weights are stored in contiguous GroupedTensor storage.
weights = [getattr(te_grouped_linear, f"weight{i}") for i in range(num_gemms)]
check_grouped_tensor_pointers(weights, fp8_recipe)

inp_hidden_states = torch.randn(
num_tokens, config.hidden_size, dtype=dtype, requires_grad=True
).cuda()
Expand Down Expand Up @@ -956,7 +1075,13 @@ def test_replace_raw_data_for_float8tensor():
random_bf16_data = torch.randn(fp8_tensor.shape, dtype=torch.bfloat16, device="cuda")
fp8_quantizer.update_quantized(random_bf16_data, fp8_tensor)

attrs_to_check = ["_quantizer", "_fp8_dtype", "_scale_inv", "_transpose", "_transpose_invalid"]
attrs_to_check = [
"_quantizer",
"_fp8_dtype",
"_scale_inv",
"_transpose",
"_transpose_invalid",
]
attrs = {}
for attr in attrs_to_check:
attrs[attr] = getattr(fp8_tensor, attr)
Expand Down
2 changes: 2 additions & 0 deletions transformer_engine/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -172,10 +172,12 @@ list(APPEND transformer_engine_cuda_arch_specific_sources
cast/cast.cu
gemm/cutlass_grouped_gemm.cu
hadamard_transform/group_hadamard_transform.cu
hadamard_transform/graph_safe_group_hadamard_transform.cu
hadamard_transform/hadamard_transform.cu
hadamard_transform/hadamard_transform_cast_fusion.cu
hadamard_transform/group_hadamard_transform_cast_fusion.cu
hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu
hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu
multi_tensor/compute_scale.cu
recipe/mxfp8_scaling.cu
transpose/quantize_transpose_square_blockwise.cu
Expand Down
Loading
Loading