Fix QDQ inference OOM issue. #1763
Conversation
Signed-off-by: changwangss <chang1.wang@intel.com>
Pull request overview
This PR addresses CUDA OOM during QDQ inference by replacing STE implementations that materialize large intermediates with custom `torch.autograd.Function`-based STEs, reducing peak temporary tensor allocations during quantization.
Changes:
- Replaced `(q - x).detach() + x` STE patterns for `round`/`floor`/`ceil` with custom autograd Functions.
- Replaced float8 cast STE patterns with custom autograd Functions (CUDA/CPU and HPU).
- Updated NVFP UE5M3 STE cast to use a custom autograd Function.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| `auto_round/data_type/utils.py` | Introduces custom autograd-based STEs for rounding ops and float8 cast STEs to reduce peak memory. |
| `auto_round/data_type/nvfp.py` | Introduces a custom autograd-based STE for UE5M3 casting to avoid extra intermediates. |
```python
class _RoundSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor):
        return torch.round(x)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        return grad_output
```
These STE helpers are now implemented via custom `torch.autograd.Function`, but there are no unit tests covering their backward semantics (e.g., verifying that `round_ste`/`floor_ste`/`ceil_ste` propagate gradients as identity and that the forward matches the corresponding rounding op). Adding small CPU tests for the forward pass plus `torch.autograd.grad` would help prevent silent regressions in quantization/tuning behavior.
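A minimal sketch of such a test, assuming `round_ste`, `floor_ste`, and `ceil_ste` are importable from `auto_round.data_type.utils` (the import path is inferred from the changed file and is an assumption):

```python
import torch

from auto_round.data_type.utils import ceil_ste, floor_ste, round_ste  # assumed import path


def test_rounding_ste_forward_and_grad():
    x = torch.randn(16, dtype=torch.float32, requires_grad=True)
    for ste, ref in ((round_ste, torch.round), (floor_ste, torch.floor), (ceil_ste, torch.ceil)):
        y = ste(x)
        # Forward should match the plain rounding op.
        torch.testing.assert_close(y.detach(), ref(x.detach()))
        # Backward should be the identity (straight-through).
        (grad,) = torch.autograd.grad(y.sum(), x)
        torch.testing.assert_close(grad, torch.ones_like(x))
```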
```python
def cast_to_ue5m3_ste(x):
    return _UE5M3CastSTE.apply(x)
```
`cast_to_ue5m3_ste` now relies on a custom `torch.autograd.Function`, but there are no tests asserting the straight-through gradient behavior for this path. Consider adding a unit test that checks forward equivalence to `cast_to_ue5m3` and that gradients through `cast_to_ue5m3_ste` are identity (within expected dtype tolerances).
```python
def _validate_cast_to_ue5m3_ste():
    """Validate the STE contract for cast_to_ue5m3_ste.

    This helper is intentionally kept local to the implementation so tests can
    exercise the same forward and backward path used in production:

    * forward output matches ``cast_to_ue5m3``
    * backward pass behaves as an identity straight-through estimator
    """
    test_specs = {
        torch.float16: {"rtol": 1e-3, "atol": 1e-3},
        torch.float32: {"rtol": 1e-6, "atol": 1e-6},
        torch.bfloat16: {"rtol": 1e-2, "atol": 1e-2},
    }
    base = torch.tensor(
        [-480.0, -31.5, -2.75, -0.5, 0.0, 0.375, 1.0, 7.5, 96.0, 57344.0],
        dtype=torch.float32,
    )
    for dtype, tol in test_specs.items():
        x = base.to(dtype).clone().detach().requires_grad_(True)
        y_ref = cast_to_ue5m3(x.detach())
        y_ste = cast_to_ue5m3_ste(x)
        torch.testing.assert_close(y_ste.detach(), y_ref, **tol)
        upstream_grad = torch.linspace(-1.0, 1.0, steps=x.numel(), dtype=torch.float32).to(dtype)
        y_ste.backward(upstream_grad)
        torch.testing.assert_close(x.grad, upstream_grad, **tol)
```
| """ | ||
| fp8 = ((torch.ops.hpu.cast_to_fp8_v2(x, 1.0, False, False, torch.float8_e4m3fn)[0]).to(x.dtype) - x).detach() + x | ||
| return fp8 | ||
| return _HpuFloat8CastSTE.apply(x, torch.float8_e4m3fn) |
`float8_e4m3fnuz_hpu_ste` is currently casting with `torch.float8_e4m3fn`, which conflicts with the function name and with the Gaudi2 FP8 path (which clips/scales using `torch.float8_e4m3fnuz`). This will change quantization behavior/range for the fnuz path. Use `torch.float8_e4m3fnuz` here to keep the dtype flavor consistent end-to-end.
Suggested change:
```diff
-    return _HpuFloat8CastSTE.apply(x, torch.float8_e4m3fn)
+    return _HpuFloat8CastSTE.apply(x, torch.float8_e4m3fnuz)
```
Description
This change addresses a CUDA OOM caused by STE implementations in the quantization path.
Previously, `floor_ste` (and similar helpers) used `(q - x).detach() + x`. On large tensors, that pattern creates extra full-size intermediates and can spike peak memory during the forward pass.

The fix replaces these STEs with custom `autograd.Function` implementations: the forward returns the quantized value directly, and the backward keeps straight-through gradients. Quantization behavior stays the same, but temporary tensor pressure is lower. At the operator level, temporary memory is reduced by roughly 2/3.
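To make the memory argument concrete, here is a minimal sketch (not code from the PR) contrasting the two STE styles; `floor_ste_old` and `floor_ste_new` are illustrative names:

```python
import torch


def floor_ste_old(x: torch.Tensor) -> torch.Tensor:
    # (q - x).detach() + x materializes q, (q - x), and the final sum:
    # roughly three full-size temporaries beyond x itself.
    q = torch.floor(x)
    return (q - x).detach() + x


class _FloorSTE(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x: torch.Tensor) -> torch.Tensor:
        # Forward allocates only the floored result.
        return torch.floor(x)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor) -> torch.Tensor:
        # Straight-through estimator: the gradient passes through unchanged.
        return grad_output


def floor_ste_new(x: torch.Tensor) -> torch.Tensor:
    return _FloorSTE.apply(x)


x = torch.randn(1 << 20, requires_grad=True)
# The new STE returns torch.floor(x) directly ...
assert torch.equal(floor_ste_new(x).detach(), torch.floor(x.detach()))
# ... and its gradient is the identity, matching the old pattern's behavior.
floor_ste_new(x).sum().backward()
assert torch.equal(x.grad, torch.ones_like(x))
```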
Type of Change
Bug fix
Related Issues
Fixes or relates to #1762