[PyTorch] Python GroupedTensor
#2654
base: main
Conversation
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Greptile Overview

Greptile Summary: This PR introduces a new Python GroupedTensor … Key integration points are the new … Confidence Score: 2/5
Important Files Changed
Sequence Diagram

sequenceDiagram
participant User as User/Trainer
participant GL as GroupedLinear
participant GT as GroupedTensor
participant Q as Quantizer
participant QT as QuantizedTensor views
User->>GL: reset_parameters()
GL->>GL: make_grouped_weights()
GL->>GL: _get_weight_quantizers()
GL->>GT: make_grouped_tensor_with_shapes(num_gemms, shapes, quantizer)
GT->>GT: allocate contiguous buffers (data/scale_inv/etc)
GT->>QT: split_into_quantized_tensors() (views into buffers)
GL->>QT: copy_ existing per-weight params into views
GL->>GL: register_parameter(weight{i} = Parameter(QT[i]))
User->>GT: quantize(list[tensors])
GT->>QT: split_into_quantized_tensors()
GT->>Q: update_quantized(tensors[i], QT[i]) (in-place)
GT-->>User: return per-tensor quantized views
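A minimal sketch of the flow in the diagram, assuming the method names shown there; old_weights, weight_quantizer, num_gemms, and module are placeholder names, not identifiers from the PR:

import torch

# Sketch only: mirrors the diagram above; exact signatures may differ.
grouped = GroupedTensor.make_grouped_tensor_with_shapes(
    num_tensors=num_gemms,                      # one buffer region per GEMM
    shape=[w.shape for w in old_weights],       # per-weight shapes
    quantizer=weight_quantizer,
)
views = grouped.split_into_quantized_tensors()  # views into the shared buffers
for i, view in enumerate(views):
    view.copy_(old_weights[i])                  # copy existing params in place
    module.register_parameter(f"weight{i}", torch.nn.Parameter(view))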
5 files reviewed, 1 comment
from .nvfp4_tensor_storage import NVFP4TensorStorage


class GroupedTensor:
I don't think it's a good idea to put everything within a single class. We should have an abstract base class (GroupedTensorBase) and concrete classes like GroupedTensor (or UnquantizedGroupTensor?), MXFP8GroupedTensor, NVFP4GroupedTensor. The giant-pile-of-attrs design results in ugly implementations (like the if-else blocks in make_grouped_tensor) and it generalizes poorly (columnwise_data is treated very differently between FP8 and MXFP8, enough that giving them the same name is questionable). We do use this design in the C++ grouped tensor class, but that should be viewed as a short-term expedient and not a long-term design (#2388 (comment)).
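A rough sketch of the kind of hierarchy being suggested here (class and method names are illustrative, not from the PR):

from abc import ABC, abstractmethod
from typing import List

class GroupedTensorBase(ABC):
    """Shared grouped-buffer bookkeeping: num_tensors, shapes, offsets."""

    @abstractmethod
    def split_into_quantized_tensors(self) -> List["QuantizedTensorStorage"]:
        """Return per-tensor views into the grouped buffers."""

class GroupedTensor(GroupedTensorBase):
    """Unquantized grouped storage."""
    def split_into_quantized_tensors(self) -> List["QuantizedTensorStorage"]:
        ...

class MXFP8GroupedTensor(GroupedTensorBase):
    """MXFP8-specific layout, e.g. block scales with their own offsets."""
    def split_into_quantized_tensors(self) -> List["QuantizedTensorStorage"]:
        ...

class NVFP4GroupedTensor(GroupedTensorBase):
    """NVFP4-specific layout."""
    def split_into_quantized_tensors(self) -> List["QuantizedTensorStorage"]:
        ...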
This ultimately depends on what we want to optimize for. If we believe that the majority of what we are going to write here is "grouped" functionality that does not really care about the underlying format (or functionality where we could delegate that decision to C++, which has full knowledge of the quantizer type and could implement things without huge if/else blocks), then it makes sense to have a single class here. If we believe that the majority of the functionality will depend on the quantization format, then I agree that we should split this into multiple classes.
@ksivaman Can you comment on that?
I think GroupedTensor in Python should be a faithful copy of the C++ grouped tensor, so I do think it's okay to have a single class.
This GroupedTensor in Python is used mostly as a storage class and is basically a copy of the C++ GroupedTensor with some additional functionality for weight initialization and convenience. I think it's best to keep it simple and avoid over-engineering at this stage. In the next steps, when we implement a QuantizedGroupedTensor, say for FSDP2 + quantized parameter support, we could revisit whether a small refactor would be helpful.
I agree matching the C++ class is reasonable for now so we can meet deadlines. In the long term, we should refactor both the C++ and Python class to have proper polymorphism. This will be a painful process.
            columnwise_scale_inv = torch.empty(
                total_columnwise_scale_elements, dtype=torch.uint8, device=device
            )
        elif quantizer._get_compatible_recipe().delayed():
Since we have gone for a single quantizer, we should remove the delayed scaling recipe and per-tensor current scaling for now, since their quantizers are not stateless.
This is what we discussed offline, but now that I think about it, this is used for FP8 parameter creation, so we cannot simply un-support recipes here. The correct approach is probably to use multiple quantizers, or at least to have a way for the user to supply multiple quantizers.
            columnwise_scale_inv = torch.empty(
                total_columnwise_scale_elements, dtype=torch.float32, device=device
            )
        elif quantizer._get_compatible_recipe().float8_current_scaling():
float8_current_scaling can work with GroupedTensor once we refactor its implementation to move the amax tensor out of its quantizer. Then it will be safe to put a single quantizer into the grouped tensor.
Same comment as above^
            result.append(tensor)

        # Delayed scaling or current scaling (both use Float8TensorStorage)
        elif recipe.delayed() or recipe.float8_current_scaling():
let's raise an error for this case?
Same comment as above^
        dtype: Optional[torch.dtype] = None,
    ) -> GroupedTensor:
        """
        Create a GroupedTensor for storing multiple weight tensors of the same shape.
Minor comment: the intent of this API is to create a grouped tensor with variable first_dims/last_dims, so we could state that in the comment, since this is not going to be used to create weights.
Also, could the API be named make_grouped_tensor_graph_safe? That way people know this API is safe to use within a module's forward/backward, which we need to be CUDA-graph capturable.
This API is actually used to create the weights and is not graph safe (for now), which is fine as it's only used once during creation.
I thought make_grouped_tensor_with_shapes is used to create weights, since the weights' shapes are going to be constant. What's the intent of make_grouped_tensor_with_shapes then?
And what's the API we are going to use to create inputs? Don't we need a graph-safe one for that?
                torch.zeros(1, device=first_dims.device, dtype=first_dims.dtype),
                torch.cumsum(first_dims * logical_last_dim, dim=0),
            ]
        )
I see the above comment about having a single kernel and am not sure what your plan is to implement that.
But with torch ops you can avoid one memory op using:
tensor_offsets = torch.empty(num_tensors + 1, device=first_dims.device, dtype=first_dims.dtype)
torch.cumsum(first_dims * logical_last_dim, dim=0, out=tensor_offsets[1:])
torch.empty would do garbage initialization whereas we need tensor_offsets[0] to be explicitly 0, so either way we'd have to do multiple kernels if using PyTorch ops. That's why the plan is to later add a small CUDA kernel so that we can call it from the C++ extensions and from JAX as well.
Oh, my mistake: using torch.zeros instead of torch.empty should do the trick. Sure, a CUDA kernel later sounds good.
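For reference, a sketch of the variant settled on here, using the same variable names as in the snippet above and assuming first_dims and logical_last_dim are 1-D integer tensors on the same device:

# Pre-zeroed buffer so tensor_offsets[0] == 0, with the cumulative sums
# written directly into the remaining slots (one allocation, one cumsum).
tensor_offsets = torch.zeros(
    num_tensors + 1, device=first_dims.device, dtype=first_dims.dtype
)
torch.cumsum(first_dims * logical_last_dim, dim=0, out=tensor_offsets[1:])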
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com> Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
/te-ci L0
9 files reviewed, 2 comments
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
        noop_flag: Optional[torch.Tensor] = None,
    ) -> Tuple[QuantizedTensorStorage, ...]:
        """
        Quantize given tensors into quantized tensors with underlying
        storage allocated in a GroupedTensor.
        """

        grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
            num_tensors=len(tensors),
            shape=[t.shape for t in tensors],
            quantizer=quantizer,
            device=device,
            dtype=dtype,
        )

        grouped_tensor.quantize(tensors, noop_flag=noop_flag)

        return grouped_tensor

    def quantize(
        self,
        tensors: List[torch.Tensor],
        noop_flag: Optional[torch.Tensor] = None,
    ) -> Tuple[QuantizedTensorStorage, ...]:
        """
Broken create_and_quantize API
create_and_quantize is annotated/declared as returning a Tuple[QuantizedTensorStorage, ...] and taking tensors: int, but the implementation uses len(tensors) / iterates tensors and returns the GroupedTensor instance (return grouped_tensor). Any caller relying on the annotated contract (tuple of quantized tensors) will break, and the tensors: int annotation is incompatible with the actual usage.
This should either return the grouped tensor’s split/quantized tensors (matching the annotation), or update the signature/return type to reflect that it returns a GroupedTensor and expects an iterable of tensors.
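One possible shape for the corrected signature, matching the second option (an illustrative sketch only, not the PR's actual code):

def create_and_quantize(
    tensors: List[torch.Tensor],
    quantizer: Quantizer,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
    noop_flag: Optional[torch.Tensor] = None,
) -> "GroupedTensor":
    """Allocate grouped storage for the given tensors, quantize them in place,
    and return the GroupedTensor that owns the shared buffers."""
    ...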
                if i < num_tensors - 1:
                    columnwise_scale_inv_offsets.append(total_columnwise_scale_elements)
            columnwise_scale_inv = torch.empty(
                total_columnwise_scale_elements, dtype=torch.uint8, device=device
            )
        elif quantizer._get_compatible_recipe().delayed():
            if rowwise_usage:
                # Allocate rowwise data buffer (1D flattened, uint8)
                data = torch.empty(total_elements, dtype=torch.uint8, device=device)
                # Scale inverse - one per tensor
                scale_inv = torch.empty(num_tensors, dtype=torch.float32, device=device)
MXFP8 columnwise scale size
In the MXFP8 columnwise_scale_inv_offsets allocation loop, scale_inv_shape = quantizer.get_scale_shape(s, False) is used even though this is the columnwise scale buffer. Later, split_into_quantized_tensors views columnwise_scale_inv with get_scale_shape(tensor_shape, True) (grouped_tensor.py:744-747). This mismatch will allocate the wrong number of elements and lead to incorrect views/out-of-bounds when columnwise scaling is enabled.
Use quantizer.get_scale_shape(s, True) when computing columnwise_scale_inv_offsets/total_columnwise_scale_elements.
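In code, something along these lines inside the allocation loop (a fragment; the variable names are taken from the snippets above, and math.prod assumes import math):

# Size the columnwise scale region with columnwise=True so the allocation
# matches the get_scale_shape(tensor_shape, True) views taken later.
scale_inv_shape = quantizer.get_scale_shape(s, True)
total_columnwise_scale_elements += math.prod(scale_inv_shape)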
    def __repr__(self) -> str:
        """String representation of the GroupedTensor."""
        return (
            f"GroupedTensor(num_tensors={self.num_tensors}, "
            f"shape={self.shape}, "
            f"logical_shape={self.logical_shape}, "
            f"dtype={self.get_dtype()})"
        )

    def __str__(self) -> str:
        """User-friendly string representation."""
        shape_info = []
        if self.all_same_shape():
            shape_info.append("uniform shape")
        else:
            if not self.all_same_first_dim():
                shape_info.append("varying first dim")
            if not self.all_same_last_dim():
                shape_info.append("varying last dim")

        return (
            f"GroupedTensor with {self.num_tensors} tensors "
            f"({', '.join(shape_info) if shape_info else 'uniform'}), "
            f"logical_shape={self.logical_shape}, "
            f"dtype={self.get_dtype()}"
        )
Why are those different?
        )

    @staticmethod
    def make_grouped_tensor(
Is the fact that this is on the Python side and not C++ intentional, or is it a TODO? What is the actual usage of this call? Is it just a helper for weight creation that is meant to be changed later?
        grouped_tensor.quantized_tensors = grouped_tensor.split_into_quantized_tensors()
        return grouped_tensor

    def split_into_quantized_tensors(
Is this mostly for debugging? In general it would be good to revisit the docs for the functions and indicate which ones we expect to be used in the typical case (and e.g. are graph safe) and which ones are for debug/not performant/not graph safe.
Description
Extracts the Python pieces of the GroupedTensor infrastructure from #2600. Since this is mainly focused on creating weights as a single GroupedTensor and exposing them as multiple QuantizedTensors for PyTorch, this portion does not need to be graph capturable.

Type of change

Changes

- Add the GroupedTensor class.
- Integrate GroupedTensor into GroupedLinear such that the parameters are contiguous.

Checklist: