Skip to content

Graph Safe Current Scaling Support for GroupedLinear Module/Ops#3143

Open
vthumbe1503 wants to merge 3 commits into
NVIDIA:mainfrom
vthumbe1503:nvfp4_and_fp8_current_scaling
Open

Graph Safe Current Scaling Support for GroupedLinear Module/Ops#3143
vthumbe1503 wants to merge 3 commits into
NVIDIA:mainfrom
vthumbe1503:nvfp4_and_fp8_current_scaling

Conversation

@vthumbe1503

Copy link
Copy Markdown
Collaborator

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

vthumbe1503 and others added 3 commits June 25, 2026 00:40
Signed-off-by: Varun Thumbe <vthumbe@nvidia.com>
Removed details about FP8 current scaling methods.

Signed-off-by: vthumbe1503 <vthumbe@nvidia.com>
@vthumbe1503 vthumbe1503 marked this pull request as ready for review June 25, 2026 00:57
@vthumbe1503

Copy link
Copy Markdown
Collaborator Author

/te-ci pytorch

@greptile-apps

greptile-apps Bot commented Jun 25, 2026

Copy link
Copy Markdown
Contributor

Greptile Summary

This PR extends the graph-safe grouped-tensor / cuBLASLt GEMM path (_GroupedLinear module and GroupedLinear op) to support FP8 per-tensor current scaling (Float8CurrentScalingQuantizer) on both Hopper (CC 9.0) and Blackwell (CC 10.x/11.0), in addition to the existing MXFP8/NVFP4 Blackwell-only support. Corresponding test parametrization is added to cover the new recipe.

  • Adds an early-return guard for Float8CurrentScalingQuantizer ahead of the Blackwell-only MXFP8/NVFP4 compute-capability check in both the module (_GroupedLinear._is_grouped_tensor_supported) and the op (GroupedLinear._use_grouped_tensor_path).
  • Refines the rowwise_data / scale_inv memory-freeing logic to only clear those fields when columnwise_data is available — correctly implemented in ops/basic/grouped_linear.py, but the if condition was accidentally appended to a comment in module/grouped_linear.py, leaving the assignments unconditional.

Confidence Score: 3/5

Not safe to merge as-is: the module-level forward stores a None activation for the weight gradient computation on any non-FP8 grouped-tensor training pass where weights require gradients.

The module/grouped_linear.py change incorrectly embeds the if fp8 and grouped_x.columnwise_data is not None: guard inside a comment, leaving grouped_x.rowwise_data = None and grouped_x.scale_inv = None unconditional. In any non-FP8 (BF16/FP16) training run that takes the grouped-tensor path with weight_requires_grad=True, the activation buffer saved for backward will be None, breaking weight gradient computation. The parallel fix in ops/basic/grouped_linear.py is correct, so only the module path is affected — but that path covers the existing BF16/FP16 training workload on Hopper and Blackwell.

transformer_engine/pytorch/module/grouped_linear.py lines 335–338 need immediate attention; tests/pytorch/test_grouped_mlp.py has a minor gap in Hopper coverage for the new fp8_current_scaling case.

Important Files Changed

Filename Overview
transformer_engine/pytorch/module/grouped_linear.py Adds Float8CurrentScalingQuantizer to the grouped-tensor path gate; however, the if fp8 and grouped_x.columnwise_data is not None: guard was accidentally placed inside a comment on line 336, causing rowwise_data and scale_inv to be cleared unconditionally — breaks wgrad for BF16/FP16 when weight_requires_grad=True.
transformer_engine/pytorch/ops/basic/grouped_linear.py Correctly adds Float8CurrentScalingQuantizer early-return in _use_grouped_tensor_path and adds the columnwise_data is not None guard before clearing rowwise_data; logic is sound and symmetric with the module-level change.
tests/pytorch/test_grouped_linear.py Properly adds Float8CurrentScaling test variant with its own fp8_available skip guard, adjusts hardware capability checks to allow FP8 current scaling on Hopper, and updates test IDs.
tests/pytorch/test_grouped_mlp.py Adds fp8_current_scaling and nvfp4_rht quantization variants, but the blanket SM100+ device skip causes the new fp8_current_scaling case to be silently skipped on Hopper, providing no coverage for the newly supported SM90 path.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[GroupedLinear forward] --> B{use_grouped_tensor_path?}
    B -->|CC check 9.0-11.0 passes| C{fp8?}
    B -->|CC out of range| Z[Legacy path]
    C -->|Yes| D{All Float8CurrentScalingQuantizer?}
    D -->|Yes| E[Return True — Hopper + Blackwell supported]
    D -->|No| F{CC >= 10.0 Blackwell?}
    F -->|Yes| G{All MXFP8 or NVFP4+RHT?}
    G -->|Yes| H[Return True — Blackwell only]
    G -->|No| Z
    F -->|No| Z
    C -->|No BF16/FP16| I{dtype BF16 or FP16?}
    I -->|Yes| J[Return True]
    I -->|No| Z
    E --> K[grouped_x = tex.group_quantize with columnwise=weight_requires_grad]
    H --> K
    J --> L[grouped_x = GroupedTensorStorage rowwise only]
    K --> M[general_grouped_gemm_for_grouped_tensor]
    L --> M
    M --> N{is_grad_enabled and weight_requires_grad?}
    N -->|ops path: with_quantized_compute and columnwise_data != None| O[Free rowwise_data — correct]
    N -->|module path: BUGGY unconditional| P[Free rowwise_data — breaks BF16/FP16 wgrad]
    O --> Q[save_for_backward]
    P --> Q
Loading
%%{init: {'theme': 'base', 'themeVariables': {"darkMode": true, "background": "#0d1117", "primaryColor": "#21262d", "primaryTextColor": "#e6edf3", "primaryBorderColor": "#8b949e", "lineColor": "#8b949e", "textColor": "#e6edf3", "edgeLabelBackground": "#161b22", "actorBkg": "#21262d", "actorBorder": "#8b949e", "actorTextColor": "#e6edf3", "actorLineColor": "#8b949e", "signalColor": "#8b949e", "signalTextColor": "#e6edf3", "noteBkgColor": "#373320", "noteBorderColor": "#d4a72c", "noteTextColor": "#f0e6c0", "labelBoxBkgColor": "#21262d", "labelBoxBorderColor": "#8b949e", "labelTextColor": "#e6edf3", "loopTextColor": "#e6edf3", "activationBkgColor": "#30363d", "activationBorderColor": "#8b949e"}}}%%
flowchart TD
    A[GroupedLinear forward] --> B{use_grouped_tensor_path?}
    B -->|CC check 9.0-11.0 passes| C{fp8?}
    B -->|CC out of range| Z[Legacy path]
    C -->|Yes| D{All Float8CurrentScalingQuantizer?}
    D -->|Yes| E[Return True — Hopper + Blackwell supported]
    D -->|No| F{CC >= 10.0 Blackwell?}
    F -->|Yes| G{All MXFP8 or NVFP4+RHT?}
    G -->|Yes| H[Return True — Blackwell only]
    G -->|No| Z
    F -->|No| Z
    C -->|No BF16/FP16| I{dtype BF16 or FP16?}
    I -->|Yes| J[Return True]
    I -->|No| Z
    E --> K[grouped_x = tex.group_quantize with columnwise=weight_requires_grad]
    H --> K
    J --> L[grouped_x = GroupedTensorStorage rowwise only]
    K --> M[general_grouped_gemm_for_grouped_tensor]
    L --> M
    M --> N{is_grad_enabled and weight_requires_grad?}
    N -->|ops path: with_quantized_compute and columnwise_data != None| O[Free rowwise_data — correct]
    N -->|module path: BUGGY unconditional| P[Free rowwise_data — breaks BF16/FP16 wgrad]
    O --> Q[save_for_backward]
    P --> Q
Loading

Comments Outside Diff (1)

  1. tests/pytorch/test_grouped_mlp.py, line 481-483 (link)

    P2 fp8_current_scaling test case always skipped on Hopper

    The blanket device_capability < (10, 0) skip applies to every quantization value, including the newly added fp8_current_scaling. However, per the implementation in ops/basic/grouped_linear.py, Float8CurrentScalingQuantizer is explicitly supported on Hopper (CC 9.0) via the grouped GEMM path — it return True before the Blackwell-only check. Because of this guard, the fp8_current_scaling parametrized case will be unconditionally skipped on SM90 hardware and the coverage gained by adding it is limited to Blackwell-only CI runs.

Reviews (1): Last reviewed commit: "Unecessary details remove" | Re-trigger Greptile

Comment on lines +335 to +338
# Free Rowwise Data if columnwise data is available for backward pass
# (For FP8 per tensor current scaling on Hopper) if fp8 and grouped_x.columnwise_data is not None:
grouped_x.rowwise_data = None
grouped_x.scale_inv = None

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Conditional guard accidentally embedded in comment — rowwise_data cleared unconditionally

The if fp8 and grouped_x.columnwise_data is not None: guard was intended to precede the two assignments on lines 337–338, but it was appended to the end of the preceding comment on line 336. As Python ignores everything after #, both grouped_x.rowwise_data = None and grouped_x.scale_inv = None now execute unconditionally whenever is_grad_enabled and weight_requires_grad are both True.

For the non-FP8 (BF16/FP16) grouped-tensor path, grouped_x.rowwise_data holds the packed activation buffer that is saved for backward and used to compute the weight gradient. Clearing it to None before ctx.save_for_backward destroys the activation data, causing the wgrad computation to operate on None — resulting in a crash or silently incorrect gradients.

The equivalent change in ops/basic/grouped_linear.py (line 1335) correctly places the condition on its own line: if with_quantized_compute and grouped_x.columnwise_data is not None:.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant