
Conversation

@ksivaman (Member) commented Feb 6, 2026

Description

Extracts the Python pieces of the GroupedTensor infrastructure from #2600. Since this is mainly focused on creating the weights as a single GroupedTensor and exposing them as multiple QuantizedTensors for PyTorch, this portion does not need to be graph capturable.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Expose a python GroupedTensor class.
  • Integrate GroupedTensor into GroupedLinear such that the parameters are contiguous (a minimal sketch of what this means in practice follows below).
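
A minimal sketch of the contiguity guarantee (illustrative only; the constructor arguments, the weight attribute names, and the shared-storage check are assumptions based on this description and the sequence diagram below, not the actual tests):

```python
# Hedged sketch: after GroupedLinear re-registers its per-GEMM weights as views
# into a shared GroupedTensor, all weights should share one backing buffer.
import torch
import transformer_engine.pytorch as te

num_gemms = 4
layer = te.GroupedLinear(num_gemms, 1024, 1024)

# Per-GEMM weights are assumed to be registered as weight0, weight1, ...
weights = [getattr(layer, f"weight{i}") for i in range(num_gemms)]
base_ptr = weights[0].untyped_storage().data_ptr()
assert all(w.untyped_storage().data_ptr() == base_ptr for w in weights)
```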

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
greptile-apps bot (Contributor) commented Feb 6, 2026

Greptile Overview

Greptile Summary

This PR introduces a new Python GroupedTensor storage abstraction (transformer_engine/pytorch/tensor/storage/grouped_tensor.py) that can allocate a single contiguous backing buffer for a set of tensors (optionally quantized via existing TE quantizers) and then expose them as multiple per-tensor QuantizedTensor/QuantizedTensorStorage views. GroupedLinear is updated to re-register its per-GEMM weight parameters as views into a shared GroupedTensor, and tests are added/updated to validate construction, splitting, quantization, and contiguity of weight storage.

Key integration points are the new GroupedTensor.make_grouped_tensor_with_shapes(...) allocator/slicer and the GroupedLinear.reset_parameters() path, which now calls make_grouped_weights() to convert weights into a grouped contiguous layout.
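
As a rough illustration of that allocator/view flow (a sketch based on the call signatures visible in this PR's diff; the quantizer=None unquantized path and the argument defaults are assumptions):

```python
# Sketch of GroupedTensor allocation plus per-tensor view split.
import torch
from transformer_engine.pytorch.tensor.storage.grouped_tensor import GroupedTensor

shapes = [(256, 128), (512, 128), (1024, 128)]  # one shape per GEMM
grouped = GroupedTensor.make_grouped_tensor_with_shapes(
    num_tensors=len(shapes),
    shape=shapes,
    quantizer=None,          # assumed to select plain storage; pass a TE quantizer for quantized storage
    device=torch.device("cuda"),
    dtype=torch.bfloat16,
)

# Views into the single contiguous backing buffer, one per logical tensor.
views = grouped.split_into_quantized_tensors()
assert len(views) == len(shapes)
```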

Confidence Score: 2/5

  • This PR has merge-blocking issues in the new GroupedTensor implementation that can cause incorrect MXFP8 columnwise scale buffer sizing and a broken create_and_quantize API contract.
  • Score reduced due to a confirmed allocation/view mismatch for MXFP8 columnwise scale_inv (will break when columnwise scaling is enabled) and a clear signature/return-type mismatch in GroupedTensor.create_and_quantize that can break callers. Other changes appear reasonable and are covered by new tests, but these two issues should be fixed before merging.
  • transformer_engine/pytorch/tensor/storage/grouped_tensor.py

Important Files Changed

Filename Overview
tests/pytorch/test_grouped_tensor.py TODO
tests/pytorch/test_sanity.py TODO
transformer_engine/common/recipe/__init__.py TODO
transformer_engine/pytorch/module/grouped_linear.py TODO
transformer_engine/pytorch/tensor/float8_tensor.py TODO
transformer_engine/pytorch/tensor/mxfp8_tensor.py TODO
transformer_engine/pytorch/tensor/nvfp4_tensor.py TODO
transformer_engine/pytorch/tensor/storage/__init__.py TODO
transformer_engine/pytorch/tensor/storage/grouped_tensor.py TODO

Sequence Diagram

sequenceDiagram
    participant User as User/Trainer
    participant GL as GroupedLinear
    participant GT as GroupedTensor
    participant Q as Quantizer
    participant QT as QuantizedTensor views

    User->>GL: reset_parameters()
    GL->>GL: make_grouped_weights()
    GL->>GL: _get_weight_quantizers()
    GL->>GT: make_grouped_tensor_with_shapes(num_gemms, shapes, quantizer)
    GT->>GT: allocate contiguous buffers (data/scale_inv/etc)
    GT->>QT: split_into_quantized_tensors() (views into buffers)
    GL->>QT: copy_ existing per-weight params into views
    GL->>GL: register_parameter(weight{i} = Parameter(QT[i]))

    User->>GT: quantize(list[tensors])
    GT->>QT: split_into_quantized_tensors()
    GT->>Q: update_quantized(tensors[i], QT[i]) (in-place)
    GT-->>User: return per-tensor quantized views

@greptile-apps greptile-apps bot (Contributor) left a comment

5 files reviewed, 1 comment

@ksivaman ksivaman added the MoE label Feb 6, 2026
from .nvfp4_tensor_storage import NVFP4TensorStorage


class GroupedTensor:
Collaborator

I don't think it's a good idea to put everything within a single class. We should have an abstract base class (GroupedTensorBase) and concrete classes like GroupedTensor (or UnquantizedGroupTensor?), MXFP8GroupedTensor, NVFP4GroupedTensor. The giant-pile-of-attrs design results in ugly implementations (like the if-else blocks in make_grouped_tensor) and it generalizes poorly (columnwise_data is treated very differently between FP8 and MXFP8, enough that giving them the same name is questionable). We do use this design in the C++ grouped tensor class, but that should be viewed as a short-term expedient and not a long-term design (#2388 (comment)).

Member

This ultimately depends on what we want to optimize for. If we believe that the majority of what we are going to write here is "grouped" functionality that does not really care about the underlying format (or functionality where we could delegate that decision to C++, which has full knowledge of the quantizer type and could implement things without huge if/else blocks), then it makes sense to have a single class here. If we believe that the majority of the functionality will depend on the quantization format, then I agree that we should split this into multiple classes.
@ksivaman Can you comment on that?

Collaborator

I think GroupedTensor in Python should be a faithful copy of the C++ grouped tensor, so I do think it's okay to have a single class.

Member Author

This GroupedTensor in Python is used mostly as a storage class and is basically a copy of the C++ GroupedTensor with some additional functionality for weight initialization and convenience. I think it's best to keep it simple and avoid over-engineering at this stage. In the next steps, when we implement a QuantizedGroupedTensor, say for FSDP2 + quantized parameter support, we could revisit whether a small refactor would be helpful.

Collaborator

I agree matching the C++ class is reasonable for now so we can meet deadlines. In the long term, we should refactor both the C++ and Python class to have proper polymorphism. This will be a painful process.

@ksivaman ksivaman marked this pull request as draft February 6, 2026 21:28
    columnwise_scale_inv = torch.empty(
        total_columnwise_scale_elements, dtype=torch.uint8, device=device
    )
elif quantizer._get_compatible_recipe().delayed():
Collaborator

Since we have gone for single quantizer, we should remove delayed scaling recipe & per-tensor current scaling for now since their quantizers are not stateless.

Member Author

This is what we discussed offline, but now that I think about it, this is used for FP8 parameter creation, so we cannot simply un-support these recipes here. The correct approach is probably to use multiple quantizers, or at least to give the user a way to supply multiple quantizers.

    columnwise_scale_inv = torch.empty(
        total_columnwise_scale_elements, dtype=torch.float32, device=device
    )
elif quantizer._get_compatible_recipe().float8_current_scaling():
Collaborator

float8_current_scaling can work with GroupedTensor once we have refactored its implementation to remove the amax tensor from its quantizer. Then it will be safe to put a single quantizer into the grouped tensor.

Member Author

Same comment as above^

    result.append(tensor)

# Delayed scaling or current scaling (both use Float8TensorStorage)
elif recipe.delayed() or recipe.float8_current_scaling():
Collaborator

let's assert an error for this case?

Member Author

Same comment as above^

    dtype: Optional[torch.dtype] = None,
) -> GroupedTensor:
    """
    Create a GroupedTensor for storing multiple weight tensors of the same shape.
Collaborator

Minor comment: the intent of this API is to create a grouped tensor with variable first_dims/last_dims, so we could state that in the comment, since this is not going to be used to create weights.

Also, the API could be named make_grouped_tensor_graph_safe, so people know it is safe to use within the forward/backward of a module that needs to be CUDA-graphable.

Member Author

This API is actually used to create the weights and is not graph safe (for now), which is fine since it's only used once during creation.

@vthumbe1503 (Collaborator) commented Feb 9, 2026

I thought make_grouped_tensor_with_shapes was used to create the weights, since the weights' shapes are going to be constant. What's the intent of make_grouped_tensor_with_shapes then?

And what's the API we are going to use to create inputs? Don't we need that one to be graph safe?

        torch.zeros(1, device=first_dims.device, dtype=first_dims.dtype),
        torch.cumsum(first_dims * logical_last_dim, dim=0),
    ]
)
Collaborator

I see the above comment about having a single kernel and am not sure what your plan is to implement that.
But with torch ops you can avoid one memory op using:

tensor_offsets = torch.empty(num_tensors + 1, device=first_dims.device, dtype=first_dims.dtype)
torch.cumsum(first_dims * logical_last_dim, dim=0, out=tensor_offsets[1:])

Member Author

torch.empty would do garbage initialization whereas we need tensor_offsets[0] to be explicitly 0, so either way we'd need multiple kernels when using PyTorch ops. That's why the plan is to later add a small CUDA kernel that we can call from the C++ extensions and also from JAX.

Collaborator

Oh, my mistake: using torch.zeros instead of torch.empty should do the trick. Sure, a CUDA kernel later sounds good.
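
For reference, the agreed variant would look roughly like this (a sketch; first_dims and logical_last_dim are the same values used in the surrounding GroupedTensor code, and the helper name is illustrative):

```python
# Hedged sketch of the torch.zeros + out= variant discussed above.
import torch

def tensor_offsets(first_dims: torch.Tensor, logical_last_dim: int) -> torch.Tensor:
    num_tensors = first_dims.numel()
    # torch.zeros guarantees offsets[0] == 0 (torch.empty would leave garbage there).
    offsets = torch.zeros(num_tensors + 1, device=first_dims.device, dtype=first_dims.dtype)
    # Write the cumulative sums directly into offsets[1:], avoiding a torch.cat.
    torch.cumsum(first_dims * logical_last_dim, dim=0, out=offsets[1:])
    return offsets
```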

ksivaman and others added 3 commits February 9, 2026 23:20
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@ksivaman (Member Author) commented

/te-ci L0

@ksivaman ksivaman marked this pull request as ready for review February 10, 2026 01:08
@greptile-apps greptile-apps bot (Contributor) left a comment

9 files reviewed, 2 comments

Comment on lines +909 to +935
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    noop_flag: Optional[torch.Tensor] = None,
) -> Tuple[QuantizedTensorStorage, ...]:
    """
    Quantize given tensors into quantized tensors with underlying
    storage allocated in a GroupedTensor.
    """

    grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
        num_tensors=len(tensors),
        shape=[t.shape for t in tensors],
        quantizer=quantizer,
        device=device,
        dtype=dtype,
    )

    grouped_tensor.quantize(tensors, noop_flag=noop_flag)

    return grouped_tensor

def quantize(
    self,
    tensors: List[torch.Tensor],
    noop_flag: Optional[torch.Tensor] = None,
) -> Tuple[QuantizedTensorStorage, ...]:
    """
Contributor

Broken create_and_quantize API
create_and_quantize is annotated/declared as returning a Tuple[QuantizedTensorStorage, ...] and taking tensors: int, but the implementation uses len(tensors) / iterates tensors and returns the GroupedTensor instance (return grouped_tensor). Any caller relying on the annotated contract (tuple of quantized tensors) will break, and the tensors: int annotation is incompatible with the actual usage.

This should either return the grouped tensor’s split/quantized tensors (matching the annotation), or update the signature/return type to reflect that it returns a GroupedTensor and expects an iterable of tensors.
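
A sketch of the second option (keep returning the GroupedTensor, but make the signature and annotation match). This is a class-body fragment, not the PR's actual code, and the parameter defaults are assumptions:

```python
# Sketch: accept a list of tensors and advertise the GroupedTensor return value
# that the body actually produces.
@staticmethod
def create_and_quantize(
    tensors: List[torch.Tensor],
    quantizer,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    noop_flag: Optional[torch.Tensor] = None,
) -> "GroupedTensor":
    """Quantize `tensors` into storage backed by a single GroupedTensor."""
    grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
        num_tensors=len(tensors),
        shape=[t.shape for t in tensors],
        quantizer=quantizer,
        device=device,
        dtype=dtype,
    )
    grouped_tensor.quantize(tensors, noop_flag=noop_flag)
    return grouped_tensor
```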

Comment on lines +469 to +479
if i < num_tensors - 1:
columnwise_scale_inv_offsets.append(total_columnwise_scale_elements)
columnwise_scale_inv = torch.empty(
total_columnwise_scale_elements, dtype=torch.uint8, device=device
)
elif quantizer._get_compatible_recipe().delayed():
if rowwise_usage:
# Allocate rowwise data buffer (1D flattened, uint8)
data = torch.empty(total_elements, dtype=torch.uint8, device=device)
# Scale inverse - one per tensor
scale_inv = torch.empty(num_tensors, dtype=torch.float32, device=device)
Contributor

MXFP8 columnwise scale size
In the MXFP8 columnwise_scale_inv_offsets allocation loop, scale_inv_shape = quantizer.get_scale_shape(s, False) is used even though this is the columnwise scale buffer. Later, split_into_quantized_tensors views columnwise_scale_inv with get_scale_shape(tensor_shape, True) (grouped_tensor.py:744-747). This mismatch will allocate the wrong number of elements and lead to incorrect views/out-of-bounds when columnwise scaling is enabled.

Use quantizer.get_scale_shape(s, True) when computing columnwise_scale_inv_offsets/total_columnwise_scale_elements.
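
A small sketch of the suggested fix, passing the columnwise flag when sizing the buffer so allocation matches the later views (the helper name is illustrative, and get_scale_shape is assumed to return an iterable of ints):

```python
# Sketch of the fix: size the columnwise scale_inv buffer with the columnwise
# scale shape (flag=True) instead of the rowwise one, so it matches the views
# created later by split_into_quantized_tensors().
import math

def columnwise_scale_sizing(quantizer, shapes):
    offsets, total = [], 0
    for s in shapes:
        offsets.append(total)
        total += math.prod(quantizer.get_scale_shape(s, True))  # columnwise shape
    return offsets, total
```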

Comment on lines +266 to +291
def __repr__(self) -> str:
    """String representation of the GroupedTensor."""
    return (
        f"GroupedTensor(num_tensors={self.num_tensors}, "
        f"shape={self.shape}, "
        f"logical_shape={self.logical_shape}, "
        f"dtype={self.get_dtype()})"
    )

def __str__(self) -> str:
    """User-friendly string representation."""
    shape_info = []
    if self.all_same_shape():
        shape_info.append("uniform shape")
    else:
        if not self.all_same_first_dim():
            shape_info.append("varying first dim")
        if not self.all_same_last_dim():
            shape_info.append("varying last dim")

    return (
        f"GroupedTensor with {self.num_tensors} tensors "
        f"({', '.join(shape_info) if shape_info else 'uniform'}), "
        f"logical_shape={self.logical_shape}, "
        f"dtype={self.get_dtype()}"
    )
Member

Why are those different?

)

@staticmethod
def make_grouped_tensor(
Member

Is the fact that this is on the Python side and not in C++ intentional, or a TODO? What is the actual usage of this call? Is it just a helper for weights creation and meant to be changed later?

    grouped_tensor.quantized_tensors = grouped_tensor.split_into_quantized_tensors()
    return grouped_tensor

def split_into_quantized_tensors(
Member

Is this mostly for debugging? In general it would be good to revisit the docs for the functions and indicate which ones we expect to be used in the typical case (and e.g. are graph safe) and which ones are for debug/not performant/not graph safe.
