Draft
Changes from all commits
34 commits
5e39835
Python GroupedTensor and contiguous weights for GroupedLinear
ksivaman Jan 15, 2026
66e7d7f
Merge branch 'main' into grouped_tensor_python
ksivaman Jan 15, 2026
40c619e
Graph safe C API for grouped RHT, needs testing
ksivaman Jan 16, 2026
cf61339
Merge branch 'main' into grouped_tensor_python
ksivaman Jan 16, 2026
759e7bb
C++ utils, untested
ksivaman Jan 16, 2026
1d09c2a
Merge branch 'main' into grouped_tensor_python
vthumbe1503 Jan 23, 2026
e1b65ac
Merge branch 'main' into grouped_tensor_python
ksivaman Feb 2, 2026
3ba639e
Pytorch Binding for GroupedTensor APIs (#13)
vthumbe1503 Feb 4, 2026
ebf2194
Merge branch 'main' into grouped_tensor_python
ksivaman Feb 4, 2026
4337520
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 4, 2026
5ab30f5
Fix make grouped tensor api
ksivaman Feb 5, 2026
05dab12
Merge branch 'main' into grouped_tensor_python
ksivaman Feb 6, 2026
68ce836
Fixes to tests
ksivaman Feb 6, 2026
3e7859c
PyTorch-Python GroupedTensor
ksivaman Feb 6, 2026
53c38ec
Merge branch 'main' into grouped_tensor_python
ksivaman Feb 8, 2026
d57651d
Fix test
ksivaman Feb 9, 2026
bd41fd0
All tests pass
ksivaman Feb 9, 2026
351b74d
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 9, 2026
fd8ce0f
Update transformer_engine/pytorch/tensor/storage/grouped_tensor.py
ksivaman Feb 9, 2026
24cfd8c
Remove mxfp8 gq test
ksivaman Feb 9, 2026
97a1f33
C++ PyTorch GroupedTensor changes WIP
ksivaman Feb 9, 2026
82f7ebe
Merge branch 'main' into pytorch_python_grouped_tensor
ksivaman Feb 10, 2026
e1788b3
Compiles
ksivaman Feb 10, 2026
9022383
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 10, 2026
11095a9
Fix runtime failure for test
ksivaman Feb 10, 2026
373d9e3
Fix IMA in mxfp8 GQ
ksivaman Feb 10, 2026
1601960
Add CG test for grouped_quantize
ksivaman Feb 10, 2026
bd57000
Fix recipe tests and FP8 weights
ksivaman Feb 10, 2026
91ab416
Fix recipe tests and FP8 weights
ksivaman Feb 10, 2026
52ab0ed
Merge branch 'main' into pytorch_python_grouped_tensor
ksivaman Feb 10, 2026
a5de7a5
Fix device test
ksivaman Feb 11, 2026
77fa728
Disable grouped weights for unsupported recipes
ksivaman Feb 11, 2026
9009f75
Merge branch 'pytorch_python_grouped_tensor' into grouped_tensor_python
ksivaman Feb 11, 2026
bea794f
Merge branch 'main' into grouped_tensor_python
ksivaman Feb 11, 2026
111 changes: 102 additions & 9 deletions tests/pytorch/test_grouped_tensor.py
@@ -55,7 +55,7 @@


def make_quantizer(quantization: str, num_tensors: int, shape: List[Tuple[int, int]]) -> Quantizer:
Collaborator comment:
We don't need "num_tensors" as an argument here anymore, I think, because we assume all tensors in the group use the same kind of quantizer.

"""Create quantizers for given quantization scheme"""
"""Create quantizer for given quantization scheme"""

if quantization == "fp8_delayed_scaling":
quantizer = Float8Quantizer(
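An editorial sketch, not part of the diff, of the simplified helper the review note above suggests: num_tensors is dropped on the assumption that all tensors in the group share one quantizer, and the constructor arguments shown here are illustrative placeholders rather than the file's actual initialization values.

def make_quantizer(quantization: str, shape: List[Tuple[int, int]]) -> Quantizer:
    """Create quantizer for given quantization scheme"""
    if quantization == "fp8_delayed_scaling":
        # Placeholder scale/amax buffers; the real helper may initialize these differently.
        return Float8Quantizer(
            scale=torch.ones(1, device="cuda"),
            amax=torch.zeros(1, device="cuda"),
            fp8_dtype=tex.DType.kFloat8E4M3,
        )
    if quantization == "mxfp8":
        return MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)
    raise ValueError(f"Unsupported quantization scheme: {quantization}")

Call sites would then shrink to quantizer = make_quantizer(quantization, shape).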
@@ -203,12 +203,12 @@ def test_split_into_quantized_tensors_quantized(self, quantization: str) -> None
"""Test split_into_quantized_tensors for quantized tensors"""
num_tensors = 3
shape = [(512, 512) for _ in range(num_tensors)]
quantizers = make_quantizer(quantization, num_tensors, shape)
quantizer = make_quantizer(quantization, num_tensors, shape)

grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
quantizer=quantizers,
quantizer=quantizer,
device="cuda",
)

@@ -260,12 +260,12 @@ def test_quantize_inplace(self, quantization: str) -> None:
"""Test that quantize is done in-place for all recipes"""
num_tensors = 3
shape = [(512, 512) for _ in range(num_tensors)]
quantizers = make_quantizer(quantization, num_tensors, shape)
quantizer = make_quantizer(quantization, num_tensors, shape)

grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
quantizer=quantizers,
quantizer=quantizer,
device="cuda",
)

@@ -300,12 +300,12 @@ def test_quantize_varying_shapes(self, quantization: str) -> None:
"""Test quantize with varying shapes"""
num_tensors = 3
shape = [(256, 512), (512, 512), (768, 512)]
quantizers = make_quantizer(quantization, num_tensors, shape)
quantizer = make_quantizer(quantization, num_tensors, shape)

grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
num_tensors=num_tensors,
shape=shape,
quantizer=quantizers,
quantizer=quantizer,
device="cuda",
)

@@ -334,15 +334,15 @@ def test_static_quantize_method(self, quantization: str) -> None:
"""Test the static quantize method"""
num_tensors = 3
shape = [(512, 512) for _ in range(num_tensors)]
quantizers = make_quantizer(quantization, num_tensors, shape)
quantizer = make_quantizer(quantization, num_tensors, shape)

# Create input tensors
input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape]

# Use static quantize method
grouped_tensor = GroupedTensor.create_and_quantize(
tensors=input_tensors,
quantizer=quantizers,
quantizer=quantizer,
device="cuda",
)

@@ -361,6 +361,99 @@ def test_static_quantize_method(self, quantization: str) -> None:
expected_offset = _rowwise_offset_bytes(i * numel, quantization)
assert rowwise_data.data_ptr() == original_data_ptr + expected_offset

@pytest.mark.parametrize(
"shape",
[[(256, 512), (512, 512), (768, 512)], [(512, 512), (512, 512), (512, 512)]],
Collaborator comment:
Please add edge cases to the test:

  1. zero tokens in all groups
  2. zero tokens at the beginning
  3. zero tokens at the end
  4. zero tokens in the middle

)
@pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8)
def test_quantize_grouped_mxfp8(self, shape: List[Tuple[int, int]]) -> None:
"""Test grouped quantization for MXFP8 against per-tensor quantization."""
# Test won't pass until the grouped quantization PR from Oleg is merged.
num_tensors = len(shape)

# Create BF16 input tensors and pack into a 2D tensor
input_tensors = [torch.randn(s, dtype=torch.bfloat16, device="cuda") for s in shape]
quantized_tensors = [
MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(tensor) for tensor in input_tensors
]
grouped_input = torch.cat(input_tensors, dim=0)

# Create MXFP8 output grouped tensor (rowwise only for easier validation)
quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)
quantizer.set_usage(rowwise=True, columnwise=False)
first_dims = torch.tensor(
[s[0] for s in shape],
dtype=torch.int64,
device="cuda",
)

# Quantize using grouped API
grouped_output = tex.group_quantize(
grouped_input,
quantizer,
num_tensors,
first_dims,
)
# Build expected output by quantizing each tensor independently
expected_data = []
expected_scale_inv = []
for tensor in input_tensors:
qtensor = quantizer(tensor)
expected_data.append(qtensor._rowwise_data.reshape(-1))
expected_scale_inv.append(qtensor._rowwise_scale_inv.reshape(-1))

expected_data = torch.cat(expected_data)
expected_scale_inv = torch.cat(expected_scale_inv)

assert torch.equal(grouped_output.data, expected_data)
assert torch.equal(grouped_output.scale_inv, expected_scale_inv)
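As a follow-up to the review comment on the parametrization above, a possible set of additional cases, shown as an editorial sketch rather than part of the diff: a first dimension of 0 models a group that receives no tokens, and the nonzero first dimensions stay multiples of 128 as the kernel requires.

@pytest.mark.parametrize(
    "shape",
    [
        [(0, 512), (0, 512), (0, 512)],      # zero tokens in all groups
        [(0, 512), (512, 512), (768, 512)],  # zero tokens at the beginning
        [(256, 512), (512, 512), (0, 512)],  # zero tokens at the end
        [(256, 512), (0, 512), (768, 512)],  # zero tokens in the middle
    ],
)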

@pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8)
def test_group_quantize_cudagraph_capturable(self) -> None:
"""Ensure group_quantize is CUDA graph capturable."""
num_tensors = 2
shape = [(512, 1024) for _ in range(num_tensors)]
input_tensors = [torch.randn(s, dtype=torch.bfloat16, device="cuda") for s in shape]
grouped_input = torch.cat(input_tensors, dim=0)

quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)
quantizer.set_usage(rowwise=True, columnwise=False)
first_dims = torch.tensor(
[shape[0][0] for _ in range(num_tensors)],
dtype=torch.int64,
device="cuda",
)

torch.cuda.synchronize()
static_input = grouped_input.clone()
static_first_dims = first_dims.clone()

# Warmup to initialize kernels and allocator state
_ = tex.group_quantize(static_input, quantizer, num_tensors, static_first_dims)
torch.cuda.synchronize()

graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(graph):
static_output = tex.group_quantize(
static_input,
quantizer,
num_tensors,
static_first_dims,
)

fresh_input = torch.cat(
[torch.randn(s, dtype=torch.bfloat16, device="cuda") for s in shape],
dim=0,
)
static_input.copy_(fresh_input)
graph.replay()
torch.cuda.synchronize()

expected = tex.group_quantize(static_input, quantizer, num_tensors, static_first_dims)
assert torch.equal(static_output.data, expected.data)
assert torch.equal(static_output.scale_inv, expected.scale_inv)

def test_clear(self) -> None:
"""Test clear method"""
num_tensors = 3
2 changes: 2 additions & 0 deletions transformer_engine/common/CMakeLists.txt
@@ -173,10 +173,12 @@ list(APPEND transformer_engine_cuda_arch_specific_sources
cast/cast.cu
gemm/cutlass_grouped_gemm.cu
hadamard_transform/group_hadamard_transform.cu
hadamard_transform/graph_safe_group_hadamard_transform.cu
hadamard_transform/hadamard_transform.cu
hadamard_transform/hadamard_transform_cast_fusion.cu
hadamard_transform/group_hadamard_transform_cast_fusion.cu
hadamard_transform/group_row_cast_col_hadamard_transform_cast_fusion.cu
hadamard_transform/graph_safe_group_row_cast_col_hadamard_transform_cast_fusion.cu
multi_tensor/compute_scale.cu
recipe/mxfp8_scaling.cu
transpose/quantize_transpose_square_blockwise.cu
6 changes: 3 additions & 3 deletions transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh
@@ -809,9 +809,9 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations
"First dimension of a grouped tensor should be divisible by 128.");
}

const int64_t *const offsets_ptr = reinterpret_cast<const int64_t *>(input->tensor_offsets.dptr);
const int64_t *const first_dims_ptr = reinterpret_cast<const int64_t *>(input->first_dims.dptr);
const int64_t *const last_dims_ptr = reinterpret_cast<const int64_t *>(input->last_dims.dptr);
const int64_t *const offsets_ptr = reinterpret_cast<const int64_t *>(output->tensor_offsets.dptr);
const int64_t *const first_dims_ptr = reinterpret_cast<const int64_t *>(output->first_dims.dptr);
const int64_t *const last_dims_ptr = reinterpret_cast<const int64_t *>(output->last_dims.dptr);

float *const workspace_ptr = IS_DBIAS ? reinterpret_cast<float *>(workspace->data.dptr) : nullptr;
float *const amax_ptr = reinterpret_cast<float *>(output->amax.dptr);