Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 68 additions & 15 deletions backends/arm/quantizer/quantization_annotator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,12 @@
from torch._subclasses import FakeTensor

from torch.fx import Node
from torchao.quantization.pt2e import PartialWrapper
from torchao.quantization.pt2e import (
FakeQuantize,
FusedMovingAvgObsFakeQuantize,
MovingAveragePerChannelMinMaxObserver,
PartialWrapper,
)
from torchao.quantization.pt2e.quantizer import (
annotate_input_qspec_map,
annotate_output_qspec,
Expand All @@ -41,6 +46,12 @@
logger = logging.getLogger(__name__)


def _is_fused_moving_avg_obs_fake_quant_ctor(func: object) -> bool:
    """Return True when ``func`` is the fused fake-quant class or a subclass.

    ``issubclass`` raises ``TypeError`` for non-class arguments, so a guard
    clause rejects anything that is not a ``type`` before the subclass test.
    """

    if not isinstance(func, type):
        return False
    return issubclass(func, FusedMovingAvgObsFakeQuantize)


@dataclass(frozen=True)
class _QuantProperty:
"""Specify how the input/output at 'index' must be quantized."""
Expand Down Expand Up @@ -85,10 +96,29 @@ def _as_list(x):
]


def _adjust_weight_qspec_for_conv_transpose(node: Node, weight_qspec):
def _adjust_weight_qspec_for_conv_transpose(
node: Node, weight_qspec: QuantizationSpec | None
) -> QuantizationSpec | None:
"""Adjust weight qspec axis/ctor for conv_transpose2d per-channel
quantization.

Use axis 1 for ungrouped ConvTranspose2d weights because the weight layout is
(in_channels, out_channels / groups, kH, kW). Grouped transpose conv keeps axis 0.

If the weight qspec contains a TorchAO QAT fake-quant/observer constructor
(e.g. PartialWrapper(partial(...)) or a with_args-based constructor), the
constructor is rebuilt with the corrected axis. For fused per-channel
FakeQuantize, which only supports axis 0, the constructor is replaced with
a non-fused FakeQuantize + MovingAveragePerChannelMinMaxObserver when the
required axis is not 0.

Return the qspec unchanged when weights are unset.

"""

if (
node.target != torch.ops.aten.conv_transpose2d.input
or not isinstance(weight_qspec, QuantizationSpec)
or weight_qspec is None
or weight_qspec.qscheme != torch.per_channel_symmetric
):
return weight_qspec
Expand All @@ -101,27 +131,42 @@ def _adjust_weight_qspec_for_conv_transpose(node: Node, weight_qspec):
if len(node.args) > 6 and isinstance(node.args[6], int):
groups = node.args[6]
expected_axis = 0 if groups != 1 else 1
if weight_qspec.ch_axis == expected_axis:
return weight_qspec

observer_or_fake_quant_ctr = weight_qspec.observer_or_fake_quant_ctr
# TorchAO PT2e QAT commonly represents the ctor as PartialWrapper(partial(...)).
# Rebuild it to update ch_axis while preserving callable_args.
observer_or_fake_quant_ctr_changed = False
# QAT FakeQuantize uses PartialWrapper; rebuild its partial to update ch_axis
# without breaking TorchAO introspection.
if isinstance(observer_or_fake_quant_ctr, PartialWrapper):
original_callable_args = dict(observer_or_fake_quant_ctr.callable_args)
base_partial = observer_or_fake_quant_ctr.p
if isinstance(base_partial, functools.partial):
base_keywords = dict(base_partial.keywords or {})
base_keywords["ch_axis"] = expected_axis
observer_or_fake_quant_ctr = PartialWrapper(
functools.partial(base_partial.func, **base_keywords)
)
if (
_is_fused_moving_avg_obs_fake_quant_ctor(base_partial.func)
and expected_axis != 0
):
# Fused per-channel FakeQuant only supports axis 0; for other axes,
# fall back to FakeQuantize with a per-channel observer.
base_keywords["observer"] = MovingAveragePerChannelMinMaxObserver
observer_or_fake_quant_ctr = PartialWrapper(
functools.partial(FakeQuantize, **base_keywords)
)
else:
observer_or_fake_quant_ctr = PartialWrapper(
functools.partial(base_partial.func, **base_keywords)
)
observer_or_fake_quant_ctr.callable_args = original_callable_args
# Non-QAT observer/fake-quant constructors can be updated via with_args.
observer_or_fake_quant_ctr_changed = True
# Non-QAT observer/fake-quant ctrs can be updated via with_args.
elif hasattr(observer_or_fake_quant_ctr, "with_args"):
observer_or_fake_quant_ctr = observer_or_fake_quant_ctr.with_args(
ch_axis=expected_axis
)
observer_or_fake_quant_ctr_changed = True

if weight_qspec.ch_axis == expected_axis and not observer_or_fake_quant_ctr_changed:
return weight_qspec

return QuantizationSpec(
dtype=weight_qspec.dtype,
Expand Down Expand Up @@ -581,9 +626,10 @@ def any_or_hardtanh_min_zero(n: Node):
filter_fn=any_or_hardtanh_min_zero,
):
if node.target in _conv_ops:
conv_weight_qspec = ensure_type(QuantizationSpec, weight_qspec) # For MyPy
quant_properties.quant_inputs = [
_QuantProperty(0, input_act_qspec),
_QuantProperty(1, weight_qspec, mark_annotated=True),
_QuantProperty(1, conv_weight_qspec, mark_annotated=True),
_QuantProperty(2, bias_qspec, optional=True, mark_annotated=True),
]
elif node.target in (
Expand All @@ -602,9 +648,10 @@ def any_or_hardtanh_min_zero(n: Node):
],
):
if node.target in _conv_ops:
conv_weight_qspec = ensure_type(QuantizationSpec, weight_qspec) # For MyPy
quant_properties.quant_inputs = [
_QuantProperty(0, input_act_qspec),
_QuantProperty(1, weight_qspec, mark_annotated=True),
_QuantProperty(1, conv_weight_qspec, mark_annotated=True),
_QuantProperty(2, bias_qspec, optional=True, mark_annotated=True),
]
elif node.target in [
Expand All @@ -631,9 +678,12 @@ def any_or_hardtanh_min_zero(n: Node):
*_conv_ops,
torch.ops.aten.linear.default,
):
conv_or_linear_weight_qspec = ensure_type(
QuantizationSpec, weight_qspec
) # For MyPy
quant_properties.quant_inputs = [
_QuantProperty(0, input_act_qspec),
_QuantProperty(1, weight_qspec, mark_annotated=True),
_QuantProperty(1, conv_or_linear_weight_qspec, mark_annotated=True),
_QuantProperty(2, bias_qspec, optional=True, mark_annotated=True),
]
else:
Expand All @@ -642,9 +692,12 @@ def any_or_hardtanh_min_zero(n: Node):
*_conv_ops,
torch.ops.aten.linear.default,
):
conv_or_linear_weight_qspec = ensure_type(
QuantizationSpec, weight_qspec
) # For MyPy
quant_properties.quant_inputs = [
_QuantProperty(0, input_act_qspec),
_QuantProperty(1, weight_qspec, mark_annotated=True),
_QuantProperty(1, conv_or_linear_weight_qspec, mark_annotated=True),
_QuantProperty(2, bias_qspec, optional=True, mark_annotated=True),
]
quant_properties.quant_output = _QuantProperty(0, output_act_qspec)
Expand Down
Loading
Loading