From a83cf41d6c9524cc91f1c8edec313a16ff1283bf Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?M=C3=A5ns=20Nilsson?=
Date: Tue, 10 Mar 2026 11:24:54 +0100
Subject: [PATCH 1/2] Arm backend: Update _adjust_weight_qspec_for_conv_transpose
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* FusedMovingAvgObsFakeQuantize only supports channel axis 0; fall back
  to the non-fused combination (MovingAveragePerChannelMinMaxObserver +
  FakeQuantize) when applicable.
* Always check/update ch_axis in the FakeQuantize/observer constructor,
  regardless of whether the channel axis is already correct in the qspec.
* Add unit tests.

Jira: MLCE-1708

Signed-off-by: Måns Nilsson
Change-Id: Ie707c3f446bbc23a454d2ce7bbf9c0cd32582e05
---
 .../arm/quantizer/quantization_annotator.py |  67 +++++-
 .../arm/test/ops/test_transpose_conv2d.py   | 225 ++++++++++++++++--
 2 files changed, 267 insertions(+), 25 deletions(-)

diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py
index 77b5965bc9a..d0b4eae8932 100644
--- a/backends/arm/quantizer/quantization_annotator.py
+++ b/backends/arm/quantizer/quantization_annotator.py
@@ -23,7 +23,12 @@
 from torch._subclasses import FakeTensor
 from torch.fx import Node
 
-from torchao.quantization.pt2e import PartialWrapper
+from torchao.quantization.pt2e import (
+    FakeQuantize,
+    FusedMovingAvgObsFakeQuantize,
+    MovingAveragePerChannelMinMaxObserver,
+    PartialWrapper,
+)
 from torchao.quantization.pt2e.quantizer import (
     annotate_input_qspec_map,
     annotate_output_qspec,
@@ -41,6 +46,12 @@
 logger = logging.getLogger(__name__)
 
 
+def _is_fused_moving_avg_obs_fake_quant_ctor(func: object) -> bool:
+    """Return True when ``func`` is the fused fake-quant class or a subclass."""
+
+    return isinstance(func, type) and issubclass(func, FusedMovingAvgObsFakeQuantize)
+
+
 @dataclass(frozen=True)
 class _QuantProperty:
     """Specify how the input/output at 'index' must be quantized."""
@@ -85,10 +96,29 @@ def _as_list(x):
     ]
 
 
-def _adjust_weight_qspec_for_conv_transpose(node: Node, weight_qspec):
+def _adjust_weight_qspec_for_conv_transpose(
+    node: Node, weight_qspec: QuantizationSpec
+) -> QuantizationSpec:
+    """Adjust the weight qspec axis/constructor for conv_transpose2d
+    per-channel quantization.
+
+    Use axis 1 for ungrouped ConvTranspose2d weights because the weight layout
+    is (in_channels, out_channels / groups, kH, kW). Grouped transpose conv
+    keeps axis 0.
+
+    If the weight qspec contains a TorchAO QAT fake-quant/observer constructor
+    (e.g. PartialWrapper(partial(...)) or a with_args-based constructor), the
+    constructor is rebuilt with the corrected axis. For fused per-channel
+    FakeQuantize, which only supports axis 0, the constructor is replaced with
+    a non-fused FakeQuantize + MovingAveragePerChannelMinMaxObserver when the
+    required axis is not 0.
+ + """ + assert isinstance( + weight_qspec, QuantizationSpec + ), f"Expected QuantizationSpec, got {type(weight_qspec)}" + if ( node.target != torch.ops.aten.conv_transpose2d.input - or not isinstance(weight_qspec, QuantizationSpec) or weight_qspec.qscheme != torch.per_channel_symmetric ): return weight_qspec @@ -101,27 +131,42 @@ def _adjust_weight_qspec_for_conv_transpose(node: Node, weight_qspec): if len(node.args) > 6 and isinstance(node.args[6], int): groups = node.args[6] expected_axis = 0 if groups != 1 else 1 - if weight_qspec.ch_axis == expected_axis: - return weight_qspec observer_or_fake_quant_ctr = weight_qspec.observer_or_fake_quant_ctr - # TorchAO PT2e QAT commonly represents the ctor as PartialWrapper(partial(...)). - # Rebuild it to update ch_axis while preserving callable_args. + observer_or_fake_quant_ctr_changed = False + # QAT FakeQuantize uses PartialWrapper; rebuild its partial to update ch_axis + # without breaking TorchAO introspection. if isinstance(observer_or_fake_quant_ctr, PartialWrapper): original_callable_args = dict(observer_or_fake_quant_ctr.callable_args) base_partial = observer_or_fake_quant_ctr.p if isinstance(base_partial, functools.partial): base_keywords = dict(base_partial.keywords or {}) base_keywords["ch_axis"] = expected_axis - observer_or_fake_quant_ctr = PartialWrapper( - functools.partial(base_partial.func, **base_keywords) - ) + if ( + _is_fused_moving_avg_obs_fake_quant_ctor(base_partial.func) + and expected_axis != 0 + ): + # Fused per-channel FakeQuant only supports axis 0; for other axes, + # fall back to FakeQuantize with a per-channel observer. + base_keywords["observer"] = MovingAveragePerChannelMinMaxObserver + observer_or_fake_quant_ctr = PartialWrapper( + functools.partial(FakeQuantize, **base_keywords) + ) + else: + observer_or_fake_quant_ctr = PartialWrapper( + functools.partial(base_partial.func, **base_keywords) + ) observer_or_fake_quant_ctr.callable_args = original_callable_args - # Non-QAT observer/fake-quant constructors can be updated via with_args. + observer_or_fake_quant_ctr_changed = True + # Non-QAT observer/fake-quant ctrs can be updated via with_args. 
     elif hasattr(observer_or_fake_quant_ctr, "with_args"):
         observer_or_fake_quant_ctr = observer_or_fake_quant_ctr.with_args(
             ch_axis=expected_axis
         )
+        observer_or_fake_quant_ctr_changed = True
+
+    if weight_qspec.ch_axis == expected_axis and not observer_or_fake_quant_ctr_changed:
+        return weight_qspec
 
     return QuantizationSpec(
         dtype=weight_qspec.dtype,
diff --git a/backends/arm/test/ops/test_transpose_conv2d.py b/backends/arm/test/ops/test_transpose_conv2d.py
index 0c1b86509c0..56123a7ddf6 100644
--- a/backends/arm/test/ops/test_transpose_conv2d.py
+++ b/backends/arm/test/ops/test_transpose_conv2d.py
@@ -8,6 +8,7 @@
 import conftest
 import torch
 
+from executorch.backends.arm.quantizer import QuantizationConfig
 from executorch.backends.arm.quantizer.arm_quantizer import (
     get_symmetric_a16w8_quantization_config,
     get_symmetric_a8w4_quantization_config,
@@ -26,6 +27,14 @@
 )
 from executorch.backends.arm.tosa.specification import TosaSpecification
 from executorch.backends.test.harness.stages.quantize import Quantize
+from torchao.quantization.pt2e import (
+    FakeQuantize,
+    FusedMovingAvgObsFakeQuantize,
+    MovingAverageMinMaxObserver,
+    MovingAveragePerChannelMinMaxObserver,
+)
+from torchao.quantization.pt2e.quantize_pt2e import prepare_pt2e, prepare_qat_pt2e
+from torchao.quantization.pt2e.quantizer import QuantizationSpec
 
 aten_op = "torch.ops.aten.conv_transpose2d.input"
 exir_op = "executorch_exir_dialects_edge__ops_aten_convolution_default"  # No edge transpose conv
@@ -112,21 +121,52 @@ def forward(self, x):
     if k.startswith("grouped,") and k.endswith("per_channel_quant=True")
 }
 
-test_data_QAT = {
-    "qat_basic": lambda: (
-        TransposeConv2d(
-            in_channels=16,
-            out_channels=4,
-            kernel_size=4,
-            stride=2,
-            padding=1,
-            groups=1,
-        ),
-        True,
-        True,
+test_data_QAT_MODEL = {
+    "qat_basic": lambda: TransposeConv2d(
+        in_channels=16,
+        out_channels=4,
+        kernel_size=4,
+        stride=2,
+        padding=1,
+        groups=1,
+    ),
+    "non_grouped": lambda: TransposeConv2d(
+        in_channels=12,
+        out_channels=16,
+        kernel_size=3,
+        stride=1,
+        padding=1,
+        groups=1,
+    ),
+    "grouped": lambda: TransposeConv2d(
+        in_channels=4,
+        out_channels=6,
+        kernel_size=3,
+        stride=1,
+        padding=1,
+        groups=2,
     ),
 }
 
+
+def _get_per_channel_fake_quants(module: torch.nn.Module):
+    """Collect (fake-quant module, observer) pairs that expose a ch_axis."""
+    result = []
+    for mod in module.modules():
+        if isinstance(mod, (FakeQuantize, FusedMovingAvgObsFakeQuantize)):
+            observer = getattr(mod, "activation_post_process", None)
+            if observer is not None and hasattr(observer, "ch_axis"):
+                result.append((mod, observer))
+    return result
+
+
+def _get_per_channel_observers(module: torch.nn.Module):
+    """Collect all per-channel moving-average min/max observers."""
+    result = []
+    for mod in module.modules():
+        if isinstance(mod, MovingAveragePerChannelMinMaxObserver):
+            result.append(mod)
+    return result
+
+
 u55_supported_test_data_INT = {
     k: v
     for k, v in test_data_INT.items()
@@ -191,10 +231,12 @@ def test_conv_transpose2d_tosa_INT(test_data):
     pipeline.run()
 
 
-@common.parametrize("test_data", test_data_QAT)
+@common.parametrize("test_data", {"qat_basic": test_data_QAT_MODEL["qat_basic"]})
 def test_conv_transpose2d_tosa_INT_qat_per_channel_quantization_pipeline(test_data):
-    model, is_per_channel, is_qat = test_data()
+    model = test_data()
     inputs = model.get_inputs()
+    is_per_channel = True
+    is_qat = True
     quantizer = TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT"))
     quantizer.set_global(
         get_symmetric_quantization_config(
@@ -214,6 +256,161 @@ def test_conv_transpose2d_tosa_INT_qat_per_channel_quantization_pipeline(test_da
     pipeline.run()
 
 
+@common.parametrize("test_data", {"non_grouped": test_data_QAT_MODEL["non_grouped"]})
{"non_grouped": test_data_QAT_MODEL["non_grouped"]}) +def test_conv_transpose2d_tosa_INT_qat_axis1_uses_non_fused_fake_quant(test_data): + model = test_data() + inputs = model.get_inputs() + quantizer = TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT")) + + activation_qspec = QuantizationSpec( + dtype=torch.int8, + qscheme=torch.per_tensor_affine, + observer_or_fake_quant_ctr=FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAverageMinMaxObserver + ), + ) + weight_qspec = QuantizationSpec( + dtype=torch.int8, + qscheme=torch.per_channel_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAveragePerChannelMinMaxObserver, ch_axis=0 + ), + ) + quantizer.set_global( + QuantizationConfig( + input_activation=activation_qspec, + output_activation=activation_qspec, + weight=weight_qspec, + bias=None, + ) + ) + + prepared = prepare_qat_pt2e( + torch.export.export(model, inputs, strict=True).module(), quantizer + ) + per_channel_fqs = _get_per_channel_fake_quants(prepared) + assert per_channel_fqs + assert all(isinstance(mod, FakeQuantize) for mod, _ in per_channel_fqs) + assert all(obs.ch_axis == 1 for _, obs in per_channel_fqs) + + +@common.parametrize("test_data", {"axis0_grouped": test_data_QAT_MODEL["grouped"]}) +def test_conv_transpose2d_tosa_INT_grouped_qat_axis0_keeps_fused_fake_quant(test_data): + model = test_data() + inputs = model.get_inputs() + quantizer = TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT")) + + activation_qspec = QuantizationSpec( + dtype=torch.int8, + qscheme=torch.per_tensor_affine, + observer_or_fake_quant_ctr=FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAverageMinMaxObserver + ), + ) + weight_qspec = QuantizationSpec( + dtype=torch.int8, + qscheme=torch.per_channel_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=FusedMovingAvgObsFakeQuantize.with_args( + observer=MovingAveragePerChannelMinMaxObserver, ch_axis=0 + ), + ) + quantizer.set_global( + QuantizationConfig( + input_activation=activation_qspec, + output_activation=activation_qspec, + weight=weight_qspec, + bias=None, + ) + ) + + prepared = prepare_qat_pt2e( + torch.export.export(model, inputs, strict=True).module(), quantizer + ) + per_channel_fqs = _get_per_channel_fake_quants(prepared) + assert per_channel_fqs + assert all( + isinstance(mod, FusedMovingAvgObsFakeQuantize) for mod, _ in per_channel_fqs + ) + assert all(obs.ch_axis == 0 for _, obs in per_channel_fqs) + + +@common.parametrize("test_data", {"non_grouped": test_data_QAT_MODEL["non_grouped"]}) +def test_conv_transpose2d_tosa_INT_ptq_observer_updates_axis(test_data): + model = test_data() + inputs = model.get_inputs() + quantizer = TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT")) + + activation_qspec = QuantizationSpec( + dtype=torch.int8, + qscheme=torch.per_tensor_affine, + observer_or_fake_quant_ctr=MovingAverageMinMaxObserver.with_args(), + ) + weight_qspec = QuantizationSpec( + dtype=torch.int8, + qscheme=torch.per_channel_symmetric, + ch_axis=0, + observer_or_fake_quant_ctr=MovingAveragePerChannelMinMaxObserver.with_args( + ch_axis=0 + ), + ) + quantizer.set_global( + QuantizationConfig( + input_activation=activation_qspec, + output_activation=activation_qspec, + weight=weight_qspec, + bias=None, + ) + ) + + prepared = prepare_pt2e( + torch.export.export(model, inputs, strict=True).module(), quantizer + ) + per_channel_obs = _get_per_channel_observers(prepared) + assert per_channel_obs + assert 
+
+
+@common.parametrize("test_data", {"non_grouped": test_data_QAT_MODEL["non_grouped"]})
+def test_conv_transpose2d_tosa_INT_qat_correct_qspec_wrong_ctor_axis(test_data):
+    model = test_data()
+    inputs = model.get_inputs()
+    quantizer = TOSAQuantizer(TosaSpecification.create_from_string("TOSA-1.0+INT"))
+
+    activation_qspec = QuantizationSpec(
+        dtype=torch.int8,
+        qscheme=torch.per_tensor_affine,
+        observer_or_fake_quant_ctr=FusedMovingAvgObsFakeQuantize.with_args(
+            observer=MovingAverageMinMaxObserver
+        ),
+    )
+    weight_qspec = QuantizationSpec(
+        dtype=torch.int8,
+        qscheme=torch.per_channel_symmetric,
+        ch_axis=1,
+        observer_or_fake_quant_ctr=FusedMovingAvgObsFakeQuantize.with_args(
+            observer=MovingAveragePerChannelMinMaxObserver, ch_axis=0
+        ),
+    )
+    quantizer.set_global(
+        QuantizationConfig(
+            input_activation=activation_qspec,
+            output_activation=activation_qspec,
+            weight=weight_qspec,
+            bias=None,
+        )
+    )
+
+    prepared = prepare_qat_pt2e(
+        torch.export.export(model, inputs, strict=True).module(), quantizer
+    )
+    per_channel_fqs = _get_per_channel_fake_quants(prepared)
+    assert per_channel_fqs
+    assert all(isinstance(mod, FakeQuantize) for mod, _ in per_channel_fqs)
+    assert all(obs.ch_axis == 1 for _, obs in per_channel_fqs)
+
+
 _a8w4_transpose_conv_xfails = {
     k: "per-channel int4 weight quantization is not supported for transpose conv yet."
     for k in test_data_INT

From b2e60406b0583bf529b47f1c23da00cb722429c2 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?M=C3=A5ns=20Nilsson?=
Date: Wed, 11 Mar 2026 10:28:40 +0100
Subject: [PATCH 2/2] Arm backend: Fix review comment
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Signed-off-by: Måns Nilsson
Change-Id: I592b4326001baba0eebf7951ef3f625f934bc9c9
---
 .../arm/quantizer/quantization_annotator.py | 26 ++++++++++++-------
 1 file changed, 17 insertions(+), 9 deletions(-)

diff --git a/backends/arm/quantizer/quantization_annotator.py b/backends/arm/quantizer/quantization_annotator.py
index d0b4eae8932..45c66301d73 100644
--- a/backends/arm/quantizer/quantization_annotator.py
+++ b/backends/arm/quantizer/quantization_annotator.py
@@ -97,8 +97,8 @@ def _as_list(x):
 
 
 def _adjust_weight_qspec_for_conv_transpose(
-    node: Node, weight_qspec: QuantizationSpec
-) -> QuantizationSpec:
+    node: Node, weight_qspec: QuantizationSpec | None
+) -> QuantizationSpec | None:
     """Adjust the weight qspec axis/constructor for conv_transpose2d
     per-channel quantization.
 
@@ -112,13 +112,13 @@ def _adjust_weight_qspec_for_conv_transpose(
     a non-fused FakeQuantize + MovingAveragePerChannelMinMaxObserver when the
     required axis is not 0.
 
+    Return the qspec unchanged when weights are unset.
+ """ - assert isinstance( - weight_qspec, QuantizationSpec - ), f"Expected QuantizationSpec, got {type(weight_qspec)}" if ( node.target != torch.ops.aten.conv_transpose2d.input + or weight_qspec is None or weight_qspec.qscheme != torch.per_channel_symmetric ): return weight_qspec @@ -626,9 +626,10 @@ def any_or_hardtanh_min_zero(n: Node): filter_fn=any_or_hardtanh_min_zero, ): if node.target in _conv_ops: + conv_weight_qspec = ensure_type(QuantizationSpec, weight_qspec) # For MyPy quant_properties.quant_inputs = [ _QuantProperty(0, input_act_qspec), - _QuantProperty(1, weight_qspec, mark_annotated=True), + _QuantProperty(1, conv_weight_qspec, mark_annotated=True), _QuantProperty(2, bias_qspec, optional=True, mark_annotated=True), ] elif node.target in ( @@ -647,9 +648,10 @@ def any_or_hardtanh_min_zero(n: Node): ], ): if node.target in _conv_ops: + conv_weight_qspec = ensure_type(QuantizationSpec, weight_qspec) # For MyPy quant_properties.quant_inputs = [ _QuantProperty(0, input_act_qspec), - _QuantProperty(1, weight_qspec, mark_annotated=True), + _QuantProperty(1, conv_weight_qspec, mark_annotated=True), _QuantProperty(2, bias_qspec, optional=True, mark_annotated=True), ] elif node.target in [ @@ -676,9 +678,12 @@ def any_or_hardtanh_min_zero(n: Node): *_conv_ops, torch.ops.aten.linear.default, ): + conv_or_linear_weight_qspec = ensure_type( + QuantizationSpec, weight_qspec + ) # For MyPy quant_properties.quant_inputs = [ _QuantProperty(0, input_act_qspec), - _QuantProperty(1, weight_qspec, mark_annotated=True), + _QuantProperty(1, conv_or_linear_weight_qspec, mark_annotated=True), _QuantProperty(2, bias_qspec, optional=True, mark_annotated=True), ] else: @@ -687,9 +692,12 @@ def any_or_hardtanh_min_zero(n: Node): *_conv_ops, torch.ops.aten.linear.default, ): + conv_or_linear_weight_qspec = ensure_type( + QuantizationSpec, weight_qspec + ) # For MyPy quant_properties.quant_inputs = [ _QuantProperty(0, input_act_qspec), - _QuantProperty(1, weight_qspec, mark_annotated=True), + _QuantProperty(1, conv_or_linear_weight_qspec, mark_annotated=True), _QuantProperty(2, bias_qspec, optional=True, mark_annotated=True), ] quant_properties.quant_output = _QuantProperty(0, output_act_qspec)