[PyTorch] GroupedTensor integration #2600
Draft: ksivaman wants to merge 34 commits into NVIDIA:main from ksivaman:grouped_tensor_python
Commits (34):
- 5e39835 (ksivaman): Python GroupedTensor and contiguous weights for GroupedLinear
- 66e7d7f (ksivaman): Merge branch 'main' into grouped_tensor_python
- 40c619e (ksivaman): Graph safe C API for grouped RHT, needs testing
- cf61339 (ksivaman): Merge branch 'main' into grouped_tensor_python
- 759e7bb (ksivaman): C++ utils, untested
- 1d09c2a (vthumbe1503): Merge branch 'main' into grouped_tensor_python
- e1b65ac (ksivaman): Merge branch 'main' into grouped_tensor_python
- 3ba639e (vthumbe1503): Pytorch Binding for GroupedTensor APIs (#13)
- ebf2194 (ksivaman): Merge branch 'main' into grouped_tensor_python
- 4337520 (pre-commit-ci[bot]): [pre-commit.ci] auto fixes from pre-commit.com hooks
- 5ab30f5 (ksivaman): Fix make grouped tensor api
- 05dab12 (ksivaman): Merge branch 'main' into grouped_tensor_python
- 68ce836 (ksivaman): Fixes to tests
- 3e7859c (ksivaman): PyTorch-Python GroupedTensor
- 53c38ec (ksivaman): Merge branch 'main' into grouped_tensor_python
- d57651d (ksivaman): Fix test
- bd41fd0 (ksivaman): All tests pass
- 351b74d (pre-commit-ci[bot]): [pre-commit.ci] auto fixes from pre-commit.com hooks
- fd8ce0f (ksivaman): Update transformer_engine/pytorch/tensor/storage/grouped_tensor.py
- 24cfd8c (ksivaman): Remove mxfp8 gq test
- 97a1f33 (ksivaman): C++ PyTorch GroupedTensor changes WIP
- 82f7ebe (ksivaman): Merge branch 'main' into pytorch_python_grouped_tensor
- e1788b3 (ksivaman): Compiles
- 9022383 (pre-commit-ci[bot]): [pre-commit.ci] auto fixes from pre-commit.com hooks
- 11095a9 (ksivaman): Fix runtime failure for test
- 373d9e3 (ksivaman): Fix IMA in mxfp8 GQ
- 1601960 (ksivaman): Add CG test for grouped_quantize
- bd57000 (ksivaman): Fix recipe tests and FP8 weights
- 91ab416 (ksivaman): Fix recipe tests and FP8 weights
- 52ab0ed (ksivaman): Merge branch 'main' into pytorch_python_grouped_tensor
- a5de7a5 (ksivaman): Fix device test
- 77fa728 (ksivaman): Disable grouped weights for unsupported recipes
- 9009f75 (ksivaman): Merge branch 'pytorch_python_grouped_tensor' into grouped_tensor_python
- bea794f (ksivaman): Merge branch 'main' into grouped_tensor_python
@@ -55,7 +55,7 @@
 def make_quantizer(quantization: str, num_tensors: int, shape: List[Tuple[int, int]]) -> Quantizer:
-    """Create quantizers for given quantization scheme"""
+    """Create quantizer for given quantization scheme"""

     if quantization == "fp8_delayed_scaling":
         quantizer = Float8Quantizer(

@@ -203,12 +203,12 @@ def test_split_into_quantized_tensors_quantized(self, quantization: str) -> None:
         """Test split_into_quantized_tensors for quantized tensors"""
         num_tensors = 3
         shape = [(512, 512) for _ in range(num_tensors)]
-        quantizers = make_quantizer(quantization, num_tensors, shape)
+        quantizer = make_quantizer(quantization, num_tensors, shape)

         grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
             num_tensors=num_tensors,
             shape=shape,
-            quantizer=quantizers,
+            quantizer=quantizer,
             device="cuda",
         )

@@ -260,12 +260,12 @@ def test_quantize_inplace(self, quantization: str) -> None:
         """Test that quantize is done in-place for all recipes"""
         num_tensors = 3
         shape = [(512, 512) for _ in range(num_tensors)]
-        quantizers = make_quantizer(quantization, num_tensors, shape)
+        quantizer = make_quantizer(quantization, num_tensors, shape)

         grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
             num_tensors=num_tensors,
             shape=shape,
-            quantizer=quantizers,
+            quantizer=quantizer,
             device="cuda",
         )

@@ -300,12 +300,12 @@ def test_quantize_varying_shapes(self, quantization: str) -> None:
         """Test quantize with varying shapes"""
         num_tensors = 3
         shape = [(256, 512), (512, 512), (768, 512)]
-        quantizers = make_quantizer(quantization, num_tensors, shape)
+        quantizer = make_quantizer(quantization, num_tensors, shape)

         grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes(
             num_tensors=num_tensors,
             shape=shape,
-            quantizer=quantizers,
+            quantizer=quantizer,
             device="cuda",
         )

@@ -334,15 +334,15 @@ def test_static_quantize_method(self, quantization: str) -> None:
         """Test the static quantize method"""
         num_tensors = 3
         shape = [(512, 512) for _ in range(num_tensors)]
-        quantizers = make_quantizer(quantization, num_tensors, shape)
+        quantizer = make_quantizer(quantization, num_tensors, shape)

         # Create input tensors
         input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape]

         # Use static quantize method
         grouped_tensor = GroupedTensor.create_and_quantize(
             tensors=input_tensors,
-            quantizer=quantizers,
+            quantizer=quantizer,
             device="cuda",
         )
@@ -361,6 +361,99 @@ def test_static_quantize_method(self, quantization: str) -> None:
             expected_offset = _rowwise_offset_bytes(i * numel, quantization)
             assert rowwise_data.data_ptr() == original_data_ptr + expected_offset

+    @pytest.mark.parametrize(
+        "shape",
+        [[(256, 512), (512, 512), (768, 512)], [(512, 512), (512, 512), (512, 512)]],
Collaborator comment (inline, on the "shape" parametrization above): please add edge cases to the test (see the illustrative sketch after this diff).
+    )
+    @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8)
+    def test_quantize_grouped_mxfp8(self, shape: List[Tuple[int, int]]) -> None:
+        """Test grouped quantization for MXFP8 against per-tensor quantization."""
+        # Test won't pass until the grouped quantization PR from Oleg is merged.
+        num_tensors = 2
+        shape = [(512, 1024) for _ in range(num_tensors)]
+
+        # Create BF16 input tensors and pack into a 2D tensor
+        input_tensors = [torch.randn(s, dtype=torch.bfloat16, device="cuda") for s in shape]
+        quantized_tensors = [
+            MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(tensor) for tensor in input_tensors
+        ]
+        grouped_input = torch.cat(input_tensors, dim=0)
+
+        # Create MXFP8 output grouped tensor (rowwise only for easier validation)
+        quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)
+        quantizer.set_usage(rowwise=True, columnwise=False)
+        first_dims = torch.tensor(
+            [shape[0][0] for _ in range(num_tensors)],
+            dtype=torch.int64,
+            device="cuda",
+        )
+
+        # Quantize using grouped API
+        grouped_output = tex.group_quantize(
+            grouped_input,
+            quantizer,
+            num_tensors,
+            first_dims,
+        )
+        # Build expected output by quantizing each tensor independently
+        expected_data = []
+        expected_scale_inv = []
+        for tensor in input_tensors:
+            qtensor = quantizer(tensor)
+            expected_data.append(qtensor._rowwise_data.reshape(-1))
+            expected_scale_inv.append(qtensor._rowwise_scale_inv.reshape(-1))
+
+        expected_data = torch.cat(expected_data)
+        expected_scale_inv = torch.cat(expected_scale_inv)
+
+        assert torch.equal(grouped_output.data, expected_data)
+        assert torch.equal(grouped_output.scale_inv, expected_scale_inv)
+
+    @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8)
+    def test_group_quantize_cudagraph_capturable(self) -> None:
+        """Ensure group_quantize is CUDA graph capturable."""
+        num_tensors = 2
+        shape = [(512, 1024) for _ in range(num_tensors)]
+        input_tensors = [torch.randn(s, dtype=torch.bfloat16, device="cuda") for s in shape]
+        grouped_input = torch.cat(input_tensors, dim=0)
+
+        quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)
+        quantizer.set_usage(rowwise=True, columnwise=False)
+        first_dims = torch.tensor(
+            [shape[0][0] for _ in range(num_tensors)],
+            dtype=torch.int64,
+            device="cuda",
+        )
+
+        torch.cuda.synchronize()
+        static_input = grouped_input.clone()
+        static_first_dims = first_dims.clone()
+
+        # Warmup to initialize kernels and allocator state
+        _ = tex.group_quantize(static_input, quantizer, num_tensors, static_first_dims)
+        torch.cuda.synchronize()
+
+        graph = torch.cuda.CUDAGraph()
+        with torch.cuda.graph(graph):
+            static_output = tex.group_quantize(
+                static_input,
+                quantizer,
+                num_tensors,
+                static_first_dims,
+            )
+
+        fresh_input = torch.cat(
+            [torch.randn(s, dtype=torch.bfloat16, device="cuda") for s in shape],
+            dim=0,
+        )
+        static_input.copy_(fresh_input)
+        graph.replay()
+        torch.cuda.synchronize()
+
+        expected = tex.group_quantize(static_input, quantizer, num_tensors, static_first_dims)
+        assert torch.equal(static_output.data, expected.data)
+        assert torch.equal(static_output.scale_inv, expected.scale_inv)
+
     def test_clear(self) -> None:
         """Test clear method"""
         num_tensors = 3
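Following up on the collaborator's request above: a purely illustrative sketch of additional "shape" parametrize entries that would exercise edge cases. None of these entries appear in the PR, and the dimensions are kept multiples of 32 on the assumption that MXFP8's block scaling constrains which shapes are legal.

# Hypothetical edge-case entries for the "shape" parametrization (not in the PR)
edge_case_shapes = [
    [(512, 512)],                # group containing a single tensor
    [(32, 512), (992, 512)],     # strongly skewed first dimensions
    [(32, 512)] * 16,            # many small members in one group
]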
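Condensed from the new tests, the grouped-quantization call pattern exercised by this PR boils down to the sketch below. The import paths for tex and MXFP8Quantizer are my assumptions about how the test module obtains these names, and tex.group_quantize is the binding introduced by this PR, so it may change before merge.

import torch
import transformer_engine_torch as tex  # assumed import; TE code aliases the bindings as `tex`
from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer  # assumed path

num_tensors = 2
shapes = [(512, 1024)] * num_tensors

# Pack the group along dim 0; every member shares the same last dimension.
grouped_input = torch.cat(
    [torch.randn(s, dtype=torch.bfloat16, device="cuda") for s in shapes], dim=0
)

# One quantizer serves the whole group (rowwise usage only here).
quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)
quantizer.set_usage(rowwise=True, columnwise=False)

# Per-member first (row) dimensions are a device tensor, so the call avoids
# host-side synchronization and stays CUDA-graph capturable.
first_dims = torch.tensor([s[0] for s in shapes], dtype=torch.int64, device="cuda")

grouped_output = tex.group_quantize(grouped_input, quantizer, num_tensors, first_dims)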
[Diffs for the remaining changed files in this PR did not load in the capture.]
Review comment: We don't need "num_tensors" as an argument here anymore, I think, because we assume all tensors in the group use the same kind of quantizer.
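If the comment refers to the make_quantizer test helper (my reading, given the docstring change from "quantizers" to "quantizer" earlier in this diff), the suggested simplification would look roughly as follows at the call sites. This is a sketch of the reviewer's idea, not a change made in the PR.

# Current call in this PR:
quantizer = make_quantizer(quantization, num_tensors, shape)

# Hypothetical call after dropping num_tensors; the group size is implied
# by len(shape), since one quantizer covers every tensor in the group:
quantizer = make_quantizer(quantization, shape)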