Add NVTE_KEEP_BACKWARD_UNQUANTIZED #2644
Conversation
Greptile Summary
This PR adds an opt-in path to run FP8/FP4 quantized forward while keeping the backward pass in higher precision, controlled primarily through a new NVTE_KEEP_BACKWARD_UNQUANTIZED environment variable and a recipe-level quantize_backward flag. In the ops fuser stack, forward fusion functions are migrated to a unified forward interface.
Confidence Score: 3/5
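As a quick illustration of how this opt-in path would be used, here is a hedged sketch based on the public transformer_engine.pytorch API rather than code from this PR; the choice of Float8CurrentScaling as the recipe and the assumption that setting the env var at process start is sufficient are mine.

```python
import os

# Assumption: the env var is consulted when modules decide what to save for
# backward, so setting it before the model runs should be sufficient.
os.environ["NVTE_KEEP_BACKWARD_UNQUANTIZED"] = "1"

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Float8CurrentScaling

layer = te.Linear(1024, 1024).cuda()
x = torch.randn(32, 1024, device="cuda", requires_grad=True)

# Example recipe; whether a given recipe allows skipping backward quantization
# is recipe-dependent (see the DelayedScaling discussion below).
recipe = Float8CurrentScaling()
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    y = layer(x)       # forward GEMM runs quantized

y.sum().backward()     # wgrad/dgrad use the saved high-precision tensors
```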
Sequence Diagram

```mermaid
sequenceDiagram
    participant U as User code
    participant A as te.autocast/FP8GlobalStateManager
    participant M as TE Module (Linear/Grouped/LNLinear)
    participant F as OperationFuser (fusible ops)
    participant UB as UserbuffersForwardLinear
    participant BL as BasicLinear._functional_forward
    U->>A: enter autocast(recipe)
    A-->>M: fp8 enabled + recipe.quantize_backward
    U->>M: forward(inp)
    M->>M: prepare_forward(inp)
    Note over M: keep_backward_unquantized = fp8 && !recipe.quantize_backward
    alt BasicLinear / non-UB fused paths
        M->>F: run pipeline
        F->>BL: _functional_forward(..., keep_backward_unquantized)
        BL-->>F: output, x_local, w
        F->>F: save_for_backward(saved_input/saved_weight)
        Note over F: if keep_backward_unquantized, save high-precision tensors
    else Userbuffers fused forward selected
        M->>F: run pipeline
        F->>UB: fuser_forward(...)
        UB-->>F: output, x_local, w
        F->>F: save_for_backward(x_local, w)
        Note over UB,F: UB path does not incorporate keep_backward_unquantized
    end
    U-->>M: loss.backward()
    M->>F: autograd backward
    alt keep_backward_unquantized honored
        F->>BL: backward uses saved high-precision tensors
        Note over BL: quantized backward disabled
    else UB path
        F->>UB: backward expects quantized tensors / quantized compute
    end
    U->>A: exit autocast
```
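The Note in the diagram is the crux: a single boolean derived from the autocast state decides whether backward sees quantized or high-precision tensors. Below is a minimal sketch of that decision and of how it selects what gets saved; the helper name and signature are illustrative, not TE's internal API.

```python
def choose_saved_tensors(fp8_enabled, recipe, input_hp, input_q, weight_hp, weight_q):
    """Pick what to stash for backward, mirroring the Note in the diagram.

    `*_hp` are high-precision tensors, `*_q` their quantized counterparts;
    this helper is a sketch, not the actual TE code path.
    """
    quantize_backward = getattr(recipe, "quantize_backward", True) if recipe is not None else True
    keep_backward_unquantized = fp8_enabled and not quantize_backward

    if keep_backward_unquantized:
        # Backward GEMMs (wgrad/dgrad) will run in high precision.
        return input_hp, weight_hp
    # Default: reuse the quantized tensors produced for the forward GEMM.
    return input_q, weight_q
```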
I'll work on potential unit test breakage.
```python
ln_out_return = None
if return_layernorm_output or return_layernorm_output_gathered:
    ln_out_return = ln_out
ln_out_hp = ln_out if keep_backward_unquantized else None
```
Storing both `ln_out` (quantized) and `ln_out_hp` (high precision) doubles the memory footprint for this activation. Verify that this overhead is acceptable for your target models, especially when training with large batch sizes or long sequences.
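One rough way to act on this is to compare peak CUDA memory for a training step with and without the flag. The harness below is only a sketch; `make_model` and the assumption that the env var takes effect when the module saves tensors for backward are hypothetical.

```python
import os
import torch
import transformer_engine.pytorch as te


def peak_training_step_memory(make_model, batch, recipe, keep_backward_unquantized):
    """Hypothetical harness: report peak allocated CUDA memory (MiB) for one step."""
    # Assumption: the env var is re-read when the module saves tensors for backward.
    os.environ["NVTE_KEEP_BACKWARD_UNQUANTIZED"] = "1" if keep_backward_unquantized else "0"
    model = make_model().cuda()
    torch.cuda.reset_peak_memory_stats()
    with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
        out = model(batch)
    out.sum().backward()
    return torch.cuda.max_memory_allocated() / 2**20
```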
```python
if enabled or calibrating:
    _validate_recipe_quantization_flags(fp8_recipe)
    quantize_forward = getattr(fp8_recipe, "quantize_forward", True)
    effective_enabled = enabled and quantize_forward
```
I am not very sure whether we should disable FP8 when quantize_forward is false.
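For reference, the following is just a reading of the excerpt above, not a statement of intended design: with `effective_enabled = enabled and quantize_forward`, setting `quantize_forward=False` makes the autocast behave as if FP8 were off for the forward pass.

```python
def effective_fp8_enabled(enabled: bool, quantize_forward: bool) -> bool:
    # Mirrors `effective_enabled = enabled and quantize_forward`:
    #   enabled=True,  quantize_forward=True  -> True   (quantized forward)
    #   enabled=True,  quantize_forward=False -> False  (autocast entered, forward stays high precision)
    #   enabled=False, any quantize_forward   -> False
    return enabled and quantize_forward
```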
```diff
 # Check if FP8 is enabled
 fp8_enabled = FP8GlobalStateManager.is_fp8_enabled()
 quantize_forward = fp8_enabled and self._quantize_forward
-quantize_backward = fp8_enabled and self._quantize_backward
+quantize_backward = (
```
Recipe None crash
`FP8GlobalStateManager.get_fp8_recipe()` can be None (e.g., if FP8 is enabled but no recipe was set), so `...get_fp8_recipe().quantize_backward` will raise an AttributeError. This makes Quantize.forward() crash in that configuration; please guard for None (e.g., default to quantized backward when the recipe is unset, or raise an explicit error with a clearer message).
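A sketch of the guard being requested; the helper and its default-to-quantized fallback when no recipe is set are assumptions, not necessarily the fix the PR will adopt.

```python
def resolve_quantize_backward(fp8_enabled: bool, module_flag: bool, recipe) -> bool:
    """Hypothetical helper mirroring the guarded logic the comment asks for.

    When no recipe is registered (recipe is None), fall back to the
    pre-existing behavior (quantized backward) instead of dereferencing None.
    """
    return (
        fp8_enabled
        and module_flag
        and (recipe is None or getattr(recipe, "quantize_backward", True))
    )
```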
```python
assert not (
    not self.quantize_forward and self.quantize_backward
), "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True."
assert self.quantize_backward, "Delayed scaling does not support quantize_backward=False."
```
This assertion prevents using NVTE_KEEP_BACKWARD_UNQUANTIZED=1 with the DelayedScaling recipe: when the env var is set, quantize_backward becomes False, so the assert fails and blocks the entire feature for this recipe type.
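If DelayedScaling genuinely cannot run an unquantized backward, one softer option is to warn and ignore the flag (or raise an error that names the env var) instead of asserting. The sketch below is written under that assumption and is not what the PR currently implements.

```python
import warnings


def resolve_delayed_scaling_quantize_backward(quantize_forward: bool, quantize_backward: bool) -> bool:
    """Hypothetical replacement for the bare asserts in the excerpt above."""
    if not quantize_forward and quantize_backward:
        raise ValueError(
            "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True."
        )
    if not quantize_backward:
        warnings.warn(
            "DelayedScaling does not support an unquantized backward; "
            "ignoring NVTE_KEEP_BACKWARD_UNQUANTIZED for this recipe.",
            UserWarning,
        )
        return True  # effective quantize_backward
    return quantize_backward
```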
```python
assert (
    not keep_backward_unquantized
), "NVTE_KEEP_BACKWARD_UNQUANTIZED is not implemented in LayerNormMLP"
```
This assertion makes LayerNormMLP crash immediately when NVTE_KEEP_BACKWARD_UNQUANTIZED=1 is set. If this module doesn't support the feature, either implement it or handle the case more gracefully with a clear error message before reaching this point.
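One way to fail more gracefully is an explicit check with an actionable message before the forward ever reaches the assert; the helper below is only a sketch of that shape, not LayerNormMLP's actual code.

```python
import os


def check_keep_backward_unquantized_supported(module_name: str) -> None:
    """Hypothetical early check for modules that don't implement the feature yet."""
    if os.environ.get("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1":
        raise NotImplementedError(
            f"NVTE_KEEP_BACKWARD_UNQUANTIZED=1 is not supported by {module_name}; "
            "unset the variable or use a module that implements unquantized backward."
        )
```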
```diff
 # Save state for backward pass
 if ctx.requires_grad:
+    saved_input = input_ if keep_backward_unquantized else x_local
+    saved_weight = self.weight if keep_backward_unquantized else w
     if is_cpu_offload_enabled():
-        mark_activation_offload(x_local)
-    ctx.save_for_backward(x_local, w)
-    ctx.with_quantized_compute = with_quantized_compute
+        mark_activation_offload(saved_input)
+    ctx.save_for_backward(saved_input, saved_weight)
```
Unnecessary saved tensors
The forward path now saves saved_input/saved_weight whenever ctx.requires_grad is true, even when weight_requires_grad or input_requires_grad is false. In cases like frozen weights (common for finetuning) or when only one side needs gradients, this saves extra tensors and can materially increase activation memory. The prior `if not weight_requires_grad: saved_input = None` / `if not input_requires_grad: saved_weight = None` logic avoided that.
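A sketch of how the old per-tensor gating could be combined with the new high-precision selection; the names mirror the excerpt, and whether this is the intended fix is an assumption.

```python
def select_saved_tensors(
    requires_grad: bool,
    weight_requires_grad: bool,
    input_requires_grad: bool,
    keep_backward_unquantized: bool,
    input_hp, input_q, weight_hp, weight_q,
):
    """Hypothetical helper: only save what backward actually needs.

    wgrad needs the (possibly high-precision) input; dgrad needs the weight.
    """
    saved_input = None
    saved_weight = None
    if requires_grad and weight_requires_grad:
        saved_input = input_hp if keep_backward_unquantized else input_q
    if requires_grad and input_requires_grad:
        saved_weight = weight_hp if keep_backward_unquantized else weight_q
    return saved_input, saved_weight
```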
Description
Add an `NVTE_KEEP_BACKWARD_UNQUANTIZED` env var for quantized fprop + high-precision wgrad & dgrad.
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: