[PyTorch] GroupedTensor integration #2600
Merged
Commits (46)
5e39835 Python GroupedTensor and contiguous weights for GroupedLinear (ksivaman)
66e7d7f Merge branch 'main' into grouped_tensor_python (ksivaman)
40c619e Graph safe C API for grouped RHT, needs testing (ksivaman)
cf61339 Merge branch 'main' into grouped_tensor_python (ksivaman)
759e7bb C++ utils, untested (ksivaman)
1d09c2a Merge branch 'main' into grouped_tensor_python (vthumbe1503)
e1b65ac Merge branch 'main' into grouped_tensor_python (ksivaman)
3ba639e Pytorch Binding for GroupedTensor APIs (#13) (vthumbe1503)
ebf2194 Merge branch 'main' into grouped_tensor_python (ksivaman)
4337520 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
5ab30f5 Fix make grouped tensor api (ksivaman)
05dab12 Merge branch 'main' into grouped_tensor_python (ksivaman)
68ce836 Fixes to tests (ksivaman)
3e7859c PyTorch-Python GroupedTensor (ksivaman)
53c38ec Merge branch 'main' into grouped_tensor_python (ksivaman)
d57651d Fix test (ksivaman)
bd41fd0 All tests pass (ksivaman)
351b74d [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
fd8ce0f Update transformer_engine/pytorch/tensor/storage/grouped_tensor.py (ksivaman)
24cfd8c Remove mxfp8 gq test (ksivaman)
97a1f33 C++ PyTorch GroupedTensor changes WIP (ksivaman)
82f7ebe Merge branch 'main' into pytorch_python_grouped_tensor (ksivaman)
e1788b3 Compiles (ksivaman)
9022383 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
11095a9 Fix runtime failure for test (ksivaman)
373d9e3 Fix IMA in mxfp8 GQ (ksivaman)
1601960 Add CG test for grouped_quantize (ksivaman)
bd57000 Fix recipe tests and FP8 weights (ksivaman)
91ab416 Fix recipe tests and FP8 weights (ksivaman)
52ab0ed Merge branch 'main' into pytorch_python_grouped_tensor (ksivaman)
a5de7a5 Fix device test (ksivaman)
77fa728 Disable grouped weights for unsupported recipes (ksivaman)
9009f75 Merge branch 'pytorch_python_grouped_tensor' into grouped_tensor_python (ksivaman)
bea794f Merge branch 'main' into grouped_tensor_python (ksivaman)
6b0c420 Merge branch 'main' into grouped_tensor_python (ksivaman)
e3278dd Merge branch 'main' into grouped_tensor_python (ksivaman)
864c484 Integrate NVFP4 Graph Safe Group Quantize (#14) (zhongbozhu)
4ee0339 improve mxfp8 unit test (zhongbozhu)
9f5f24c pre-swizzle nvfp4 mxfp8 for MoE (zhongbozhu)
22f8a5b avoid having nvte_get_grouped_tensor_param_v2 (zhongbozhu)
63e1563 more tests (zhongbozhu)
4d66324 fix group quantize mxfp8 kernel (zhongbozhu)
621fb0e Relaxed restriction for the last dim to be a multiple of 128 (Oleg-Goncharov)
439c933 [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot])
3de9850 Merge branch 'main' into grouped_tensor_python (ksivaman)
5bf0cf9 Merge branch 'main' into grouped_tensor_python (ksivaman)
Diff (GroupedLinear module):

```diff
@@ -13,6 +13,7 @@
 import transformer_engine_torch as tex

 from transformer_engine.common.recipe import Recipe
+from transformer_engine.pytorch.tensor.storage.grouped_tensor import GroupedTensor
 from .base import (
     get_dummy_wgrad,
     TransformerEngineBaseModule,
```
```diff
@@ -147,7 +148,10 @@ def forward(
             # tensors (like scales), but bulk allocation shares storage across all tensors,
             # so if scales can't be offloaded, nothing in the group can be offloaded.
             inputmats = tex.split_quantize(
-                inp_view, m_splits, input_quantizers, disable_bulk_allocation=cpu_offloading
+                inp_view,
+                m_splits,
+                input_quantizers,
+                disable_bulk_allocation=cpu_offloading,
             )
         elif debug:
             inputmats = DebugQuantizer.multi_tensor_quantize(
```
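The comment in this hunk carries the design constraint: bulk allocation backs every tensor in the group with one shared buffer, so offloading any single tensor cannot release device memory. A standalone sketch of that effect in plain PyTorch (illustrative only, not TE internals; all names here are made up for the example):

```python
import torch

# Three grouped tensors, as in a multi-GEMM weight/input group.
shapes = [(4, 8), (2, 8), (6, 8)]
total = sum(r * c for r, c in shapes)

# Bulk path: one backing buffer; each "tensor" is a view into it.
buf = torch.empty(total)
views, offset = [], 0
for r, c in shapes:
    views.append(buf[offset : offset + r * c].view(r, c))
    offset += r * c

# Offloading one view copies its data out, but the shared buffer stays
# alive as long as any view exists, so no device memory is reclaimed.
offloaded = views[0].to("cpu", copy=True)
assert views[1].untyped_storage().data_ptr() == buf.untyped_storage().data_ptr()

# Per-tensor path (disable_bulk_allocation=True): independent storages,
# so each tensor can be offloaded and freed on its own.
separate = [torch.empty(r, c) for r, c in shapes]
```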
```diff
@@ -365,7 +369,10 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
                 for i in range(ctx.num_gemms):
                     grad_biases[i] = grad_output_mats[i].sum(dim=0)
                 grad_output = DebugQuantizer.multi_tensor_quantize(
-                    grad_output_view, ctx.grad_output_quantizers, ctx.m_splits, ctx.activation_dtype
+                    grad_output_view,
+                    ctx.grad_output_quantizers,
+                    ctx.m_splits,
+                    ctx.activation_dtype,
                 )
             else:
                 # Only split grad output. Grad bias is fused with
```
```diff
@@ -436,7 +443,8 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None],
         if ctx.input_quantizers[0] is not None:
             for input_quantizer in ctx.input_quantizers:
                 if isinstance(
-                    input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)
+                    input_quantizer,
+                    (Float8Quantizer, Float8CurrentScalingQuantizer),
                 ):
                     input_quantizer.set_usage(rowwise=True, columnwise=True)
                 else:
```
```diff
@@ -446,7 +454,10 @@
             inputmats = tex.split_quantize(inp_view, ctx.m_splits, ctx.input_quantizers)
         elif ctx.debug:
             inputmats = DebugQuantizer.multi_tensor_quantize(
-                inp_view, ctx.input_quantizers, ctx.m_splits, ctx.activation_dtype
+                inp_view,
+                ctx.input_quantizers,
+                ctx.m_splits,
+                ctx.activation_dtype,
             )
         else:
             inputmats = torch.split(
```
```diff
@@ -616,7 +627,7 @@ def __init__(
     ) -> None:
         super().__init__()

-        params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
+        self.params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype
         self.num_gemms = num_gemms
         self.in_features = in_features
         self.out_features = out_features
```
```diff
@@ -631,13 +642,20 @@ def __init__(
         assert (
             not ub_overlap_rs and not ub_overlap_ag
         ), "GroupedLinear doesn't support Userbuffer overlap."
+        self.init_method = init_method
+        self.get_rng_state_tracker = get_rng_state_tracker
+        self.rng_tracker_name = rng_tracker_name
         self.name = name

         self.wgrad_store = WeightGradStore(delay_wgrad_compute)

-        self._offsets = {"input": 0, "weight": 1, "output": 2, "grad_output": 0, "grad_input": 1}
+        self._offsets = {
+            "input": 0,
+            "weight": 1,
+            "output": 2,
+            "grad_output": 0,
+            "grad_input": 1,
+        }
         self._num_fp8_tensors_per_gemm = {
             "fwd": 3,
             "bwd": 2,
```
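The `_offsets` and `_num_fp8_tensors_per_gemm` tables above define where each GEMM's tensors land in the FP8 metadata, which is what the `fp8_meta_index` computation in `make_grouped_weights` (further down in this diff) relies on. A quick illustration with the values copied from this diff (the loop itself is only for demonstration):

```python
# Forward FP8 tensors are packed per GEMM as [input, weight, output],
# so the weight slot for GEMM i sits at 1 + 3 * i.
offsets = {"input": 0, "weight": 1, "output": 2, "grad_output": 0, "grad_input": 1}
num_fp8_tensors_per_gemm = {"fwd": 3, "bwd": 2}

for i in range(4):
    weight_index = offsets["weight"] + i * num_fp8_tensors_per_gemm["fwd"]
    print(f"GEMM {i}: weight fp8_meta_index = {weight_index}")  # 1, 4, 7, 10
```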
```diff
@@ -679,7 +697,7 @@ def __init__(
                     self.out_features,
                     self.in_features,
                     device=device,
-                    dtype=params_dtype,
+                    dtype=self.params_dtype,
                 ),
             ),
             init_fn=init_method,
```
```diff
@@ -695,20 +713,21 @@ def __init__(
                     torch.empty(
                         self.out_features,
                         device=device,
-                        dtype=params_dtype,
+                        dtype=self.params_dtype,
                     ),
                 ),
                 init_fn=init_method_constant(0.0),
             )
         else:
-            bias = torch.Tensor().to(dtype=params_dtype, device=device)
+            bias = torch.Tensor().to(dtype=self.params_dtype, device=device)
             setattr(self, f"bias{i}", bias)

         if self.primary_weights_in_fp8:
             self.init_fp8_metadata(num_gemms=self.num_gemms)

         is_meta = torch.device(device).type == "meta"
         self.reset_parameters(defer_init=is_meta)
+        self.make_grouped_weights(defer_init=is_meta)

         if self.wgrad_store.delay_wgrad_compute():
             for name, param in self.named_parameters():
```
```diff
@@ -729,8 +748,49 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
         )
         self._customize_quantizers_float8_current_scaling(fwd, recipe)

+    def make_grouped_weights(self, defer_init=False) -> None:
+        """
+        Convert parameters into a GroupedTensor and re-register them as parameters.
+        """
+
+        if defer_init:
+            return
+
+        weights = [getattr(self, f"weight{i}") for i in range(self.num_gemms)]
+        weight_quantizers = self._get_weight_quantizers()
+
+        # Create the weight storage.
+        grouped_weights = GroupedTensor.make_grouped_tensor(
+            num_tensors=self.num_gemms,
+            shape=[(self.out_features, self.in_features)] * self.num_gemms,
+            quantizers=weight_quantizers,
+            dtype=self.params_dtype,
+        )
+
+        # Copy existing params into storage.
+        # TODO(ksivamani): Verify correctness of copy for all recipes.
+        with torch.no_grad():
+            for i in range(self.num_gemms):
+                grouped_weights.quantized_tensors[i].copy_(weights[i])
```
Contributor comment on the copy loop above: style: check that the copy operation works correctly for all quantization recipes (FP8, MXFP8, NVFP4, block scaling). The TODO comment on line 771 acknowledges this needs verification.
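One way to address this would be a round-trip check along these lines; a hedged sketch, where `quantized_tensors` comes from this diff but the `dequantize()` read-back, the helper name, and the tolerance are assumptions rather than verified TE APIs:

```python
import torch

def check_copy_roundtrip(weights, grouped_weights, atol=0.1):
    """Hypothetical check: each grouped copy should match its original
    weight up to the quantization error of the active recipe
    (FP8, MXFP8, NVFP4, block scaling)."""
    with torch.no_grad():
        for i, w in enumerate(weights):
            grouped_weights.quantized_tensors[i].copy_(w)
            # Dequantizing the copy should reproduce w up to quantization error.
            restored = grouped_weights.quantized_tensors[i].dequantize()
            torch.testing.assert_close(restored, w.to(restored.dtype), atol=atol, rtol=0.0)
```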
```diff
+
+        # Re-register the grouped weights as parameters.
+        for i in range(self.num_gemms):
+            self.register_parameter(
+                f"weight{i}",
+                torch.nn.Parameter(grouped_weights.quantized_tensors[i]),
+                init_fn=self.init_method,
+                get_rng_state_tracker=self.get_rng_state_tracker,
+                fp8_meta_index=self._offsets["weight"] + i * self._num_fp8_tensors_per_gemm["fwd"],
+            )
+
+        self.set_tensor_parallel_attributes(defer_init=defer_init)
+
     def reset_parameters(self, defer_init=False):
         super().reset_parameters(defer_init=defer_init)
         self.set_tensor_parallel_attributes(defer_init=defer_init)

+    def set_tensor_parallel_attributes(self, defer_init=False) -> None:
+        """Set attributes needed for TP"""
+
+        if not defer_init:
+            # Set parallelism attributes for linear weights
```
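For orientation, a minimal usage sketch of the module this diff changes, assuming the public `GroupedLinear` API (constructor takes `num_gemms`, `in_features`, `out_features`; forward takes the packed input plus per-GEMM row splits). With this PR, `weight0` through `weight{num_gemms-1}` become views into one contiguous `GroupedTensor`:

```python
import torch
import transformer_engine.pytorch as te

num_gemms, in_features, out_features = 4, 128, 256
layer = te.GroupedLinear(num_gemms, in_features, out_features, bias=True, device="cuda")

m_splits = [16, 32, 8, 24]  # rows routed to each GEMM, e.g. by an MoE router
inp = torch.randn(sum(m_splits), in_features, device="cuda")
out = layer(inp, m_splits)  # shape: (sum(m_splits), out_features)
```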