
Fix QDQ inference OOM issue. #1763

Open
changwangss wants to merge 2 commits into main from wangchang/fix_oom

Fix QDQ inference OOM issue.#1763
changwangss wants to merge 2 commits into
mainfrom
wangchang/fix_oom

Conversation


@changwangss changwangss commented Apr 29, 2026

Description

This change addresses a CUDA OOM caused by STE implementations in the quantization path.

Previously, floor_ste (and similar helpers) used (q - x).detach() + x. On large tensors, that pattern creates extra full-size intermediates and can spike peak memory during forward.

The fix replaces these STEs with custom autograd.Function implementations: forward returns the quantized value directly, backward keeps straight-through gradients. Quantization behavior stays the same, but temporary tensor pressure is lower.

At the operator level, temporary memory is reduced by roughly two-thirds.
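For illustration, a minimal before/after sketch of this pattern change (names here are illustrative, not necessarily the exact helpers in auto_round/data_type/utils.py):

    import torch

    # Old pattern: the floored tensor and the (q - x) difference are extra
    # full-size temporaries that stay alive alongside the final output.
    def floor_ste_old(x: torch.Tensor) -> torch.Tensor:
        return (torch.floor(x) - x).detach() + x

    # New pattern: a custom autograd.Function returns the floored value directly
    # in forward and passes the incoming gradient straight through in backward.
    class _FloorSTE(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x: torch.Tensor) -> torch.Tensor:
            return torch.floor(x)

        @staticmethod
        def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
            return grad_output

    def floor_ste(x: torch.Tensor) -> torch.Tensor:
        return _FloorSTE.apply(x)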

Type of Change

Bug fix

Related Issues

#1762


Signed-off-by: changwangss <chang1.wang@intel.com>
Copilot AI review requested due to automatic review settings April 29, 2026 05:50
Contributor

Copilot AI left a comment


Pull request overview

This PR addresses CUDA OOM during QDQ inference by replacing STE implementations that materialize large intermediates with custom torch.autograd.Function-based STEs, reducing peak temporary tensor allocations during quantization.

Changes:

  • Replaced (q - x).detach() + x STE patterns for round/floor/ceil with custom autograd Functions.
  • Replaced float8 cast STE patterns with custom autograd Functions (CUDA/CPU and HPU); see the sketch after this list.
  • Updated NVFP UE5M3 STE cast to use a custom autograd Function.
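A rough sketch of what the float8 cast rewrite described above can look like (the class and function names are assumptions for illustration; the PR's actual helpers may differ):

    import torch

    class _Float8CastSTE(torch.autograd.Function):
        """Straight-through estimator for a float8 round-trip cast."""

        @staticmethod
        def forward(ctx, x: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
            # Cast to the float8 dtype and back without building the
            # (q - x).detach() + x temporaries of the old pattern.
            return x.to(dtype).to(x.dtype)

        @staticmethod
        def backward(ctx, grad_output: torch.Tensor):
            # Identity gradient for x; no gradient for the dtype argument.
            return grad_output, None

    def float8_e4m3fn_ste(x: torch.Tensor) -> torch.Tensor:
        return _Float8CastSTE.apply(x, torch.float8_e4m3fn)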

Reviewed changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.

Files reviewed:

  • auto_round/data_type/utils.py: Introduces custom autograd-based STEs for rounding ops and float8 cast STEs to reduce peak memory.
  • auto_round/data_type/nvfp.py: Introduces a custom autograd-based STE for UE5M3 casting to avoid extra intermediates.

Comment on lines +211 to +218
class _RoundSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor):
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        return grad_output

Copilot AI Apr 29, 2026


These STE helpers are now implemented via custom torch.autograd.Function, but there are no unit tests covering their backward semantics (e.g., verifying that round_ste/floor_ste/ceil_ste propagate gradients as identity and that the forward matches the corresponding rounding op). Adding small CPU tests for forward + torch.autograd.grad would help prevent silent regressions in quantization/tuning behavior.
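A minimal CPU test along those lines might look like the following (a sketch, assuming round_ste, floor_ste, and ceil_ste are importable from auto_round.data_type.utils):

    import torch
    from auto_round.data_type.utils import ceil_ste, floor_ste, round_ste

    def test_rounding_ste_forward_and_backward():
        for ste, ref in [(round_ste, torch.round), (floor_ste, torch.floor), (ceil_ste, torch.ceil)]:
            x = torch.randn(64, requires_grad=True)
            y = ste(x)
            # Forward must match the plain rounding op.
            torch.testing.assert_close(y, ref(x))
            # Backward must be the identity (straight-through).
            (grad,) = torch.autograd.grad(y.sum(), x)
            torch.testing.assert_close(grad, torch.ones_like(x))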

def cast_to_ue5m3_ste(x):
    return _UE5M3CastSTE.apply(x)



Copilot AI Apr 29, 2026


cast_to_ue5m3_ste now relies on a custom torch.autograd.Function, but there are no tests asserting the straight-through gradient behavior for this path. Consider adding a unit test that checks forward equivalence to cast_to_ue5m3 and that gradients through cast_to_ue5m3_ste are identity (within expected dtype tolerances).

Suggested change
def _validate_cast_to_ue5m3_ste():
    """Validate the STE contract for cast_to_ue5m3_ste.

    This helper is intentionally kept local to the implementation so tests can
    exercise the same forward and backward path used in production:

    * forward output matches ``cast_to_ue5m3``
    * backward pass behaves as an identity straight-through estimator
    """
    test_specs = {
        torch.float16: {"rtol": 1e-3, "atol": 1e-3},
        torch.float32: {"rtol": 1e-6, "atol": 1e-6},
        torch.bfloat16: {"rtol": 1e-2, "atol": 1e-2},
    }
    base = torch.tensor(
        [-480.0, -31.5, -2.75, -0.5, 0.0, 0.375, 1.0, 7.5, 96.0, 57344.0],
        dtype=torch.float32,
    )
    for dtype, tol in test_specs.items():
        x = base.to(dtype).clone().detach().requires_grad_(True)
        y_ref = cast_to_ue5m3(x.detach())
        y_ste = cast_to_ue5m3_ste(x)
        torch.testing.assert_close(y_ste.detach(), y_ref, **tol)
        upstream_grad = torch.linspace(-1.0, 1.0, steps=x.numel(), dtype=torch.float32).to(dtype)
        y_ste.backward(upstream_grad)
        torch.testing.assert_close(x.grad, upstream_grad, **tol)

"""
fp8 = ((torch.ops.hpu.cast_to_fp8_v2(x, 1.0, False, False, torch.float8_e4m3fn)[0]).to(x.dtype) - x).detach() + x
return fp8
return _HpuFloat8CastSTE.apply(x, torch.float8_e4m3fn)

Copilot AI Apr 29, 2026


float8_e4m3fnuz_hpu_ste is currently casting with torch.float8_e4m3fn, which conflicts with the function name and with the Gaudi2 FP8 path (which clips/scales using torch.float8_e4m3fnuz). This will change quantization behavior/range for the fnuz path. Use torch.float8_e4m3fnuz here to keep the dtype flavor consistent end-to-end.

Suggested change
-     return _HpuFloat8CastSTE.apply(x, torch.float8_e4m3fn)
+     return _HpuFloat8CastSTE.apply(x, torch.float8_e4m3fnuz)
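For context, a hypothetical sketch of what the _HpuFloat8CastSTE Function could look like with the dtype passed in by the caller (the real implementation in auto_round/data_type/utils.py may differ):

    import torch

    class _HpuFloat8CastSTE(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
            # cast_to_fp8_v2 returns a tuple; element 0 is the casted tensor,
            # which is converted back to the input dtype for downstream ops.
            fp8 = torch.ops.hpu.cast_to_fp8_v2(x, 1.0, False, False, dtype)[0]
            return fp8.to(x.dtype)

        @staticmethod
        def backward(ctx, grad_output: torch.Tensor):
            # Straight-through: identity gradient for x, no gradient for dtype.
            return grad_output, None

With the suggested change applied, the fnuz path would call _HpuFloat8CastSTE.apply(x, torch.float8_e4m3fnuz), keeping the dtype flavor consistent with the Gaudi2 clip/scale path.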
