From a16d94d6d431918f80a14cf3d2d8f383eb324cc0 Mon Sep 17 00:00:00 2001 From: Samaresh Kumar Singh Date: Wed, 13 May 2026 14:56:56 -0500 Subject: [PATCH] Set exclude_outside=1 on Resize for antialiased upsample ONNX Resize defaults exclude_outside=0 while PyTorch's antialiased upsample kernels drop out of bounds samples from the filter window and renormalize the remaining weights so they sum to one. Without exclude_outside=1 the exported model produces values that differ from eager near the boundary. This also drops the shape only test guard for the AA upsample ops so the existing numerical comparison runs again. Fixes pytorch/pytorch#183546. --- onnxscript/function_libs/torch_lib/ops/nn.py | 14 ++++++++++++++ tests/function_libs/torch_lib/ops_test_data.py | 14 -------------- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/onnxscript/function_libs/torch_lib/ops/nn.py b/onnxscript/function_libs/torch_lib/ops/nn.py index de89ff6bad..c2cf8fd49d 100644 --- a/onnxscript/function_libs/torch_lib/ops/nn.py +++ b/onnxscript/function_libs/torch_lib/ops/nn.py @@ -2330,6 +2330,7 @@ def _aten_upsample_output_size( coordinate_transformation_mode: str, antialias: int = 0, cubic_coeff_a: float = -0.75, + exclude_outside: int = 0, ) -> TReal: batch_and_channel = op.Shape(self, end=2, start=0) # When output_size is passed in as a list of integers, the torch.onnx @@ -2348,6 +2349,7 @@ def _aten_upsample_output_size( cubic_coeff_a=cubic_coeff_a, nearest_mode="floor", antialias=antialias, + exclude_outside=exclude_outside, ) @@ -2358,6 +2360,7 @@ def _aten_upsample_scales( coordinate_transformation_mode: str, antialias: int = 0, cubic_coeff_a: float = -0.75, + exclude_outside: int = 0, ) -> TReal: return op.Resize( self, @@ -2371,6 +2374,7 @@ def _aten_upsample_scales( cubic_coeff_a=cubic_coeff_a, nearest_mode="floor", antialias=antialias, + exclude_outside=exclude_outside, ) @@ -2410,6 +2414,10 @@ def aten__upsample_bicubic2d_aa( coordinate_transformation_mode = _get_upsample_align_corners_mode(align_corners) # PyTorch uses cubic_coeff_a=-0.5 (Keys interpolation, PIL-compatible) when # antialias=True, as opposed to -0.75 (OpenCV-compatible) for the non-antialias case. + # exclude_outside=1 matches PyTorch's antialias kernel, which drops out-of-bounds + # samples from the filter window and renormalizes the remaining weights so they + # sum to 1. Without it the ONNX Resize keeps phantom out-of-bounds weight in the + # denominator and produces values that differ from eager near the boundary. return _aten_upsample_output_size( self, output_size, @@ -2417,6 +2425,7 @@ def aten__upsample_bicubic2d_aa( coordinate_transformation_mode=coordinate_transformation_mode, antialias=1, cubic_coeff_a=-0.5, + exclude_outside=1, ) @@ -2495,12 +2504,17 @@ def aten__upsample_bilinear2d_aa( # NOTE: Based on experimentation, scales_h and scales_w are always ignored in PyTorch, # unless when align_corners is True, in which case we do not know what is going on. coordinate_transformation_mode = _get_upsample_align_corners_mode(align_corners) + # exclude_outside=1 matches PyTorch's antialias kernel, which drops out-of-bounds + # samples from the filter window and renormalizes the remaining weights so they + # sum to 1. Without it the ONNX Resize keeps phantom out-of-bounds weight in the + # denominator and produces values that differ from eager near the boundary. return _aten_upsample_output_size( self, output_size, coordinate_transformation_mode=coordinate_transformation_mode, mode="linear", antialias=1, + exclude_outside=1, ) diff --git a/tests/function_libs/torch_lib/ops_test_data.py b/tests/function_libs/torch_lib/ops_test_data.py index 4bf98f587d..f4af54812e 100644 --- a/tests/function_libs/torch_lib/ops_test_data.py +++ b/tests/function_libs/torch_lib/ops_test_data.py @@ -1784,13 +1784,6 @@ def _where_input_wrangler( TorchLibOpInfo( "ops.aten._upsample_bilinear2d_aa", nn_ops.aten__upsample_bilinear2d_aa, - # ONNX and PyTorch use different anti-aliasing algorithms, so numerical results differ. - # However, the implementation is verified correct because: - # 1. The function correctly passes antialias=1 to ONNX Resize operation - # 2. Shape validation ensures the operation works correctly - # 3. Additional validation in test_aa_upsample_validation.py confirms correctness - # Shape-only comparison is the appropriate testing approach for this case. - compare_shape_only_for_output=(0,), ), TorchLibOpInfo("ops.aten.upsample_bilinear2d.vec", nn_ops.aten_upsample_bilinear2d_vec), TorchLibOpInfo( @@ -1805,13 +1798,6 @@ def _where_input_wrangler( TorchLibOpInfo( "ops.aten._upsample_bicubic2d_aa", nn_ops.aten__upsample_bicubic2d_aa, - # ONNX and PyTorch use different anti-aliasing algorithms, so numerical results differ. - # However, the implementation is verified correct because: - # 1. The function correctly passes antialias=1 to ONNX Resize operation - # 2. Shape validation ensures the operation works correctly - # 3. Additional validation in test_aa_upsample_validation.py confirms correctness - # Shape-only comparison is the appropriate testing approach for this case. - compare_shape_only_for_output=(0,), ), TorchLibOpInfo("ops.aten.upsample_bicubic2d.vec", nn_ops.aten_upsample_bicubic2d_vec), TorchLibOpInfo(