diff --git a/backends/arm/quantizer/TARGETS b/backends/arm/quantizer/TARGETS index 28bfe15b528..22f65a03cf1 100644 --- a/backends/arm/quantizer/TARGETS +++ b/backends/arm/quantizer/TARGETS @@ -17,9 +17,11 @@ runtime.python_library( deps = [ ":arm_quantizer_utils", ":quantization_annotator", + ":quantizer_support", "//executorch/backends/arm:constants", "//executorch/backends/arm:ethosu", "//executorch/backends/arm:vgf", + "//executorch/backends/cortex_m/quantizer:quantizer", "//executorch/backends/arm/tosa:specification", "//executorch/backends/arm:arm_compile_spec", "//caffe2:torch", @@ -48,6 +50,16 @@ runtime.python_library( ], ) +runtime.python_library( + name = "quantizer_support", + srcs = ["quantizer_support.py"], + deps = [ + ":quantization_annotator", + "//caffe2:torch", + "//executorch/backends/cortex_m/quantizer:quantizer", + ], +) + runtime.python_library( name = "lib", srcs = ["__init__.py"], diff --git a/backends/arm/quantizer/__init__.py b/backends/arm/quantizer/__init__.py index 270d56a68cd..8d1404c7982 100644 --- a/backends/arm/quantizer/__init__.py +++ b/backends/arm/quantizer/__init__.py @@ -11,17 +11,34 @@ """ from .quantization_config import QuantizationConfig # noqa # usort: skip -from .arm_quantizer import ( # noqa - EthosUQuantizer, - get_symmetric_a16w8_quantization_config, - get_symmetric_quantization_config, - TOSAQuantizer, - VgfQuantizer, -) # Used in tests from .arm_quantizer_utils import is_annotated # noqa +# Lazily import heavy quantizer classes to avoid circular imports with +# Cortex-M quantization configs. +_LAZY_EXPORTS = { + "EthosUQuantizer": "executorch.backends.arm.quantizer.arm_quantizer", + "get_symmetric_a16w8_quantization_config": "executorch.backends.arm.quantizer.arm_quantizer", + "get_symmetric_quantization_config": "executorch.backends.arm.quantizer.arm_quantizer", + "TOSAQuantizer": "executorch.backends.arm.quantizer.arm_quantizer", + "VgfQuantizer": "executorch.backends.arm.quantizer.arm_quantizer", +} + + +def __getattr__(name: str): + if name not in _LAZY_EXPORTS: + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + import importlib + + module = importlib.import_module(_LAZY_EXPORTS[name]) + return getattr(module, name) + + +def __dir__(): + return sorted(list(globals().keys()) + list(_LAZY_EXPORTS.keys())) + + # Load quantized ops library. try: import executorch.extension.pybindings.portable_lib diff --git a/backends/arm/quantizer/arm_quantizer.py b/backends/arm/quantizer/arm_quantizer.py index 2f312179e8f..7f54b4aba9c 100644 --- a/backends/arm/quantizer/arm_quantizer.py +++ b/backends/arm/quantizer/arm_quantizer.py @@ -9,18 +9,53 @@ # # Quantizer for Arm backend # - from __future__ import annotations import functools +import logging from typing import Any, Callable, Dict, List, Optional import torch +from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo from executorch.backends.arm.constants import DISALLOW_TFA_META_KEY from executorch.backends.arm.ethosu import EthosUCompileSpec - -from executorch.backends.arm.quantizer import QuantizationConfig +from executorch.backends.arm.quantizer.quantization_config import ( + QuantizationConfig, + TOSAQuantizationConfig, +) +from executorch.backends.arm.quantizer.quantizer_support import ( + TOSA_QUANTIZER_SUPPORT_DICT, +) from executorch.backends.arm.tosa import TosaSpecification +from executorch.backends.cortex_m.quantizer.node_finders import ( + GlobalNodeFinder, + InputNodeFinder, + ModuleNameNodeFinder, + ModuleTypeNodeFinder, + NodeFinder, + NodeNameNodeFinder, + NodeTargetNodeFinder, + OutputNodeFinder, +) +from executorch.backends.cortex_m.quantizer.pattern_matcher import PatternMatcher +from executorch.backends.cortex_m.quantizer.quantizer import ( + PatternQuantizer, + SharedQspecQuantizer, +) + +from executorch.backends.cortex_m.quantizer.quantizer_reporter import ( + QuantizerReporter, + SUPPORTED_QCONFIGS, + SUPPORTED_QSPECS, +) +from torch._ops import OpOverload + +from torchao.quantization.pt2e.quantizer import ( + ComposableQuantizer, + QuantizationAnnotation, + Quantizer, +) +from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY from executorch.backends.arm.common.arm_compile_spec import ( ArmCompileSpec, ) # isort: skip @@ -30,6 +65,10 @@ ) from executorch.backends.arm.vgf import VgfCompileSpec +from executorch.backends.cortex_m.quantizer.quantization_configs import ( + _get_int32_bias_qspec, + _get_int32_per_channel_bias_qspec, +) from torch.fx import GraphModule, Node from torchao.quantization.pt2e import ( FakeQuantize, @@ -52,12 +91,12 @@ annotate_output_qspec, get_module_name_filter, QuantizationSpec, - Quantizer, ) from .arm_quantizer_utils import is_annotated, mark_node_as_annotated from .quantization_annotator import annotate_graph + __all__ = [ "TOSAQuantizer", "EthosUQuantizer", @@ -66,6 +105,8 @@ "get_symmetric_quantization_config", ] +logger = logging.getLogger(__name__) + @functools.lru_cache def get_symmetric_quantization_config( @@ -79,6 +120,9 @@ def get_symmetric_quantization_config( ) -> QuantizationConfig: """Create symmetric quantization config for activations and weights. + Activations use an affine qscheme; "symmetric" refers to the weight + quantization qscheme. + Args: is_per_channel (bool): Whether to use per-channel quantization for weights. @@ -165,16 +209,20 @@ def get_symmetric_quantization_config( observer_or_fake_quant_ctr=weight_observer_or_fake_quant_ctr, ) - bias_quantization_spec = None + if is_per_channel: + bias_quantization_spec = _get_int32_per_channel_bias_qspec + else: + bias_quantization_spec = _get_int32_bias_qspec + if is_dynamic: - quantization_config = QuantizationConfig( + quantization_config = TOSAQuantizationConfig( act_quantization_spec, None, weight_quantization_spec, bias_quantization_spec, ) else: - quantization_config = QuantizationConfig( + quantization_config = TOSAQuantizationConfig( act_quantization_spec, act_quantization_spec, weight_quantization_spec, @@ -260,22 +308,58 @@ def get_symmetric_a16w8_quantization_config( ) # Replace activation quantization spec with 16-bit version if is_dynamic: - quantization_config = QuantizationConfig( + quantization_config = TOSAQuantizationConfig( act_quantization_spec, # 16-bit input activations None, base_config.weight, # 8-bit weights from base config - None, + base_config.bias, # bias from base config ) else: - quantization_config = QuantizationConfig( + quantization_config = TOSAQuantizationConfig( act_quantization_spec, # 16-bit input activations act_quantization_spec, # 16-bit output activations base_config.weight, # 8-bit weights from base config - None, + base_config.bias, # bias from base config ) return quantization_config +# Register supported quantization configs and qspecs in the reporter for human-readable reporting +# MLETORCH-1854: Temporary solution, refactor to automatically register these instead +_symmetric_a8w4_config_per_channel = get_symmetric_a8w4_quantization_config() +_symmetric_a8w8_config_per_channel = get_symmetric_quantization_config() +_symmetric_a16w8_config_per_channel = get_symmetric_a16w8_quantization_config() +_symmetric_a8w4_config_per_tensor = get_symmetric_a8w4_quantization_config( + is_per_channel=False +) +_symmetric_a8w8_config_per_tensor = get_symmetric_quantization_config( + is_per_channel=False +) +_symmetric_a16w8_config_per_tensor = get_symmetric_a16w8_quantization_config( + is_per_channel=False +) +SUPPORTED_QCONFIGS.update( + { + _symmetric_a8w8_config_per_channel: f"{__name__}.get_symmetric_quantization_config(is_per_channel=True)", + _symmetric_a16w8_config_per_channel: f"{__name__}.get_symmetric_a16w8_quantization_config(is_per_channel=True)", + _symmetric_a8w4_config_per_channel: f"{__name__}.get_symmetric_a8w4_quantization_config(is_per_channel=True)", + _symmetric_a8w8_config_per_tensor: f"{__name__}.get_symmetric_quantization_config(is_per_channel=False)", + _symmetric_a16w8_config_per_tensor: f"{__name__}.get_symmetric_a16w8_quantization_config(is_per_channel=False)", + _symmetric_a8w4_config_per_tensor: f"{__name__}.get_symmetric_a8w4_quantization_config(is_per_channel=False)", + } +) + +SUPPORTED_QSPECS.update( + { + _symmetric_a8w4_config_per_channel.get_weight_qspec(): "INT4_PER_CHANNEL_QSPEC", + _symmetric_a8w8_config_per_channel.get_weight_qspec(): "INT8_PER_CHANNEL_QSPEC", + _symmetric_a8w8_config_per_tensor.get_weight_qspec(): "INT8_PER_TENSOR_QSPEC", + _symmetric_a8w4_config_per_tensor.get_weight_qspec(): "INT4_PER_TENSOR_QSPEC", + _symmetric_a8w8_config_per_tensor.get_input_act_qspec(): "INT8_PER_TENSOR_QSPEC", + _symmetric_a16w8_config_per_tensor.get_input_act_qspec(): "INT16_PER_TENSOR_QSPEC", + } +) + NodeFilterType = Callable[[Node], bool] """Type for a Node Filter used by annotators. @@ -358,41 +442,48 @@ class TOSAQuantizer(Quantizer): """Manage quantization annotations for TOSA-compatible backends.""" def __init__( - self, compile_spec_or_tosa_spec: TosaSpecification | ArmCompileSpec + self, + compile_spec_or_tosa_spec, + use_composable_quantizer: bool = False, ) -> None: - super().__init__() - self.compile_spec: ArmCompileSpec - if isinstance(compile_spec_or_tosa_spec, TosaSpecification): - from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec - - self.compile_spec = TosaCompileSpec(compile_spec_or_tosa_spec) - self.tosa_spec = self.compile_spec.tosa_spec - elif isinstance(compile_spec_or_tosa_spec, ArmCompileSpec): - self.compile_spec = compile_spec_or_tosa_spec - self.tosa_spec = self.compile_spec.tosa_spec + """Create a TOSA quantizer from a TOSA spec or Arm compile spec.""" + self.use_composable_quantizer = use_composable_quantizer + self.quantizer: _TOSAQuantizerV1 | _TOSAQuantizerV2 + if use_composable_quantizer: + logger.info( + "Using composable quantizer implementation in the arm backend. See https://github.com/pytorch/executorch/issues/17701" + ) + self.quantizer = _TOSAQuantizerV2(compile_spec_or_tosa_spec) else: - raise TypeError( - f"TOSAQuantizer constructor expects " - f"a TosaSpecification or compile_spec list, " - f"got {type(compile_spec_or_tosa_spec)}" + logger.info( + "Using default quantizer in the arm backend. This quantizer is planned to be replaced by the composable quantizer implementation in the future, see https://github.com/pytorch/executorch/issues/17701" ) + self.quantizer = _TOSAQuantizerV1(compile_spec_or_tosa_spec) - self.global_config: Optional[QuantizationConfig] = None - self.io_config: Optional[QuantizationConfig] = None - self.module_type_config: Dict[Callable, Optional[QuantizationConfig]] = {} - self.module_name_config: Dict[str, Optional[QuantizationConfig]] = {} + @property + def tosa_spec(self): + return self.quantizer.tosa_spec + + @property + def compile_spec(self): + return self.quantizer.compile_spec + + @property + def global_config(self): + return self.quantizer.global_config def set_global( - self, quantization_config: QuantizationConfig | None + self, quantization_config: Optional[QuantizationConfig] ) -> TOSAQuantizer: """Set quantization_config for submodules not matched by other filters. Args: - quantization_config (QuantizationConfig): Configuration to apply to - modules that are not captured by name or type filters. + quantization_config (Optional[QuantizationConfig]): Configuration to + apply to modules that are not captured by name or type filters. + ``None`` indicates no quantization. """ - self.global_config = quantization_config + self.quantizer.set_global(quantization_config) return self def set_module_type( @@ -400,17 +491,18 @@ def set_module_type( ) -> TOSAQuantizer: """Set quantization_config for submodules with a given module type. - For example, calling set_module_type(Sub) quantizes supported patterns - in each Sub instance with the provided quantization_config. + For example, calling set_module_type(Softmax) quantizes supported + patterns in each Softmax instance with the provided quantization_config. Args: module_type (Callable): Type whose submodules should use the provided quantization configuration. - quantization_config (QuantizationConfig): Configuration to apply to - submodules of the given type. + quantization_config (Optional[QuantizationConfig]): Configuration to + apply to submodules of the given type. ``None`` indicates no + quantization. """ - self.module_type_config[module_type] = quantization_config + self.quantizer.set_module_type(module_type, quantization_config) return self def set_module_name( @@ -423,22 +515,266 @@ def set_module_name( Args: module_name (str): Fully qualified module name to configure. - quantization_config (QuantizationConfig): Configuration applied to - the named submodule. + quantization_config (Optional[QuantizationConfig]): Configuration + applied to the named submodule. ``None`` indicates no + quantization. """ - # Validate that quantization_config is provided - self.module_name_config[module_name] = quantization_config + self.quantizer.set_module_name(module_name, quantization_config) return self - def set_io(self, quantization_config: QuantizationConfig) -> TOSAQuantizer: + def set_io( + self, quantization_config: Optional[QuantizationConfig] + ) -> TOSAQuantizer: """Set quantization_config for input and output nodes. Args: - quantization_config (QuantizationConfig): Configuration describing - activation quantization for model inputs and outputs. + quantization_config (Optional[QuantizationConfig]): Configuration + describing activation quantization for model inputs and outputs. + ``None`` indicates no quantization. + + """ + self.quantizer.set_io(quantization_config) + return self + + def add_quantizer(self, quantizer: Quantizer) -> TOSAQuantizer: + """Insert a quantizer with highest precedence.""" + if self.use_composable_quantizer: + return self.quantizer.add_quantizer(quantizer) # type: ignore[union-attr,return-value] + raise NotImplementedError( + "add_quantizer is only supported in the composable quantizer implementation." + ) + + def set_node_finder( + self, quantization_config: Optional[QuantizationConfig], node_finder: NodeFinder + ) -> TOSAQuantizer: + """Set quantization_config for nodes matched by a custom NodeFinder. + + Args: + quantization_config (Optional[QuantizationConfig]): Configuration + describing quantization settings for nodes matched by the provided + NodeFinder. ``None`` indicates no quantization. + + """ + if self.use_composable_quantizer: + return self.quantizer.set_node_finder(quantization_config, node_finder) # type: ignore[union-attr,return-value] + raise NotImplementedError( + "set_node_finder is only supported in the composable quantizer implementation." + ) + + def set_node_target( + self, node_target: OpOverload, quantization_config: Optional[QuantizationConfig] + ) -> TOSAQuantizer: + """Set quantization config for a specific operator target.""" + if self.use_composable_quantizer: + return self.quantizer.set_node_target(node_target, quantization_config) # type: ignore[union-attr,return-value] + raise NotImplementedError( + "set_node_target is only supported in the composable quantizer implementation." + ) + + def set_node_name( + self, node_name: str, quantization_config: Optional[QuantizationConfig] + ) -> TOSAQuantizer: + """Set quantization config for a specific node name.""" + if self.use_composable_quantizer: + return self.quantizer.set_node_name(node_name, quantization_config) # type: ignore[union-attr,return-value] + raise NotImplementedError( + "set_node_name is only supported in the composable quantizer implementation." + ) + + def transform_for_annotation(self, model: GraphModule) -> GraphModule: + """Transform the graph to prepare it for quantization annotation. + + Decomposes all operators where required to get correct quantization parameters. + + Args: + model (GraphModule): Model whose graph will be transformed. + + Returns: + GraphModule: Transformed model prepared for annotation. + + """ + return self.quantizer.transform_for_annotation(model) + + def annotate(self, model: GraphModule) -> GraphModule: + """Annotate the graph with the configured quantization settings. + + Currently only does static quantization annotation. + + Args: + model (GraphModule): Model to annotate statically. + + Returns: + GraphModule: Annotated model ready for export. + + """ + return self.quantizer.annotate(model) + + def validate(self, model: GraphModule) -> None: + """Validate the quantization results. Currently, this includes: + - Ensure tensor inputs to each operator live on the same device. + + Args: + model (GraphModule): GraphModule being validated. + Raises: + ValueError: If tensor inputs for any operator span more than one + device. + """ + for node in model.graph.nodes: + if node.op != "call_function": + continue + + devices = set() + for arg_node in node.all_input_nodes: + meta_val = arg_node.meta.get("val", None) + if meta_val is None: + continue + if isinstance(meta_val, (tuple, list)): + for tensor in meta_val: + devices.add( + str( + getattr( + tensor, + "device", + f"Could not get device from {tensor}", + ) + ) + ) + else: + devices.add( + str( + getattr( + meta_val, + "device", + f"Could not get device from {meta_val}", + ) + ) + ) + + if len(devices) > 1: + raise ValueError( + f"Quantizer detected operator {node.name} with different device inputs: {devices}." + ) + + def quantize_with_submodules( + self, + model: GraphModule, + calibration_samples: list[tuple], + is_qat: bool = False, + fold_quantize: bool = True, + ): + """Quantizes a GraphModule in a way such that conditional submodules are + handled properly. + + Note: torchao's prepare_pt2e and convert_pt2e natively handle + while_loop body_fn submodules, so we only manually process cond + branches and while_loop cond_fn here. + + Args: + model (GraphModule): The model to quantize. + calibration_samples (list[tuple]): A list of inputs to used to + calibrate the model during quantization. To properly calibrate a + model with submodules, at least one sample per code path is + needed. + is_qat (bool): Whether to do quantization aware training or not. + fold_quantize (bool): Enables or disables constant folding when quantization + is completed. + + Returns: + GraphModule: The quantized model. """ + prepare_fn = prepare_qat_pt2e if is_qat else prepare_pt2e + + prepared = prepare_fn(model, self) + # Prepare conditional submodules (e.g., if/while bodies) + # prepare only cond branches and while_loop cond_fn + for name, submodule, _ in get_cond_while_submodules_nested( + prepared, apply_quantization=True + ): + prepared.set_submodule(name, prepare_fn(submodule, self), strict=True) + for submodule_node in submodule.graph.nodes: + if is_submodule_node(submodule_node): + for nested_name, nested_sub, _ in get_cond_while_submodules_nested( + submodule, apply_quantization=True + ): + prepared.set_submodule( + nested_name, prepare_fn(nested_sub, self), strict=True + ) + + for inp in calibration_samples: + prepared(*inp) + + # Prepare conditional submodules (e.g., if/while bodies) + # convert only cond branches and while_loop cond_fn + for _, submodule, _ in get_cond_while_submodules_nested( + prepared, apply_quantization=True + ): + converted = convert_pt2e(submodule) + for submodule_node in submodule.graph.nodes: + if is_submodule_node(submodule_node): + for nested_name, nested_sub, _ in get_cond_while_submodules_nested( + submodule, apply_quantization=True + ): + converted.set_submodule( + nested_name, convert_pt2e(nested_sub), strict=True + ) + + return convert_pt2e(prepared) + + +class _TOSAQuantizerV1(Quantizer): + + def __init__( + self, compile_spec_or_tosa_spec: TosaSpecification | ArmCompileSpec + ) -> None: + super().__init__() + self.compile_spec: ArmCompileSpec + if isinstance(compile_spec_or_tosa_spec, TosaSpecification): + from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec + + self.compile_spec = TosaCompileSpec(compile_spec_or_tosa_spec) + self.tosa_spec = self.compile_spec.tosa_spec + elif isinstance(compile_spec_or_tosa_spec, ArmCompileSpec): + self.compile_spec = compile_spec_or_tosa_spec + self.tosa_spec = self.compile_spec.tosa_spec + else: + raise TypeError( + f"TOSAQuantizer constructor expects " + f"a TosaSpecification or compile_spec list, " + f"got {type(compile_spec_or_tosa_spec)}" + ) + + self.global_config: Optional[QuantizationConfig] = None + self.io_config: Optional[QuantizationConfig] = None + self.module_type_config: Dict[Callable, Optional[QuantizationConfig]] = {} + self.module_name_config: Dict[str, Optional[QuantizationConfig]] = {} + + def set_global( + self, quantization_config: Optional[QuantizationConfig] + ) -> _TOSAQuantizerV1: + + self.global_config = quantization_config + return self + + def set_module_type( + self, module_type: Callable, quantization_config: Optional[QuantizationConfig] + ) -> _TOSAQuantizerV1: + + self.module_type_config[module_type] = quantization_config + return self + + def set_module_name( + self, module_name: str, quantization_config: Optional[QuantizationConfig] + ) -> _TOSAQuantizerV1: + + # Validate that quantization_config is provided + self.module_name_config[module_name] = quantization_config + return self + + def set_io( + self, quantization_config: Optional[QuantizationConfig] + ) -> _TOSAQuantizerV1: self.io_config = quantization_config return self @@ -469,18 +805,6 @@ def _set_disallow_tfa_for_nodes(self, model: GraphModule) -> None: node.meta[DISALLOW_TFA_META_KEY] = config is None def transform_for_annotation(self, model: GraphModule) -> GraphModule: - """Transform the graph to prepare it for quantization annotation. - - Currently transforms scalar values to tensor attributes. - - Args: - model (GraphModule): Model whose graph will be transformed. - - Returns: - GraphModule: Transformed model prepared for annotation. - - """ - self._set_disallow_tfa_for_nodes(model) # TODO: Fix the need to lazily import this. @@ -490,17 +814,6 @@ def transform_for_annotation(self, model: GraphModule) -> GraphModule: return pass_manager.transform_for_annotation_pipeline(graph_module=model) def annotate(self, model: GraphModule) -> GraphModule: - """Annotate the graph with the configured quantization settings. - - Currently only does static quantization annotation. - - Args: - model (GraphModule): Model to annotate statically. - - Returns: - GraphModule: Annotated model ready for export. - - """ model = self._annotate_for_static_quantization_config(model) return model @@ -597,116 +910,194 @@ def _annotate_io( mark_node_as_annotated(node) def validate(self, model: GraphModule) -> None: - """Validate the quantization results. Currently, this includes: - - Ensure tensor inputs to each operator live on the same device. + # Validation is handled by TOSAQuantizer.validate; keep no-op for + # Quantizer interface compatibility. + return None + + +class _TOSAQuantizerV2(ComposableQuantizer): + + def __init__( + self, compile_spec_or_tosa_spec: TosaSpecification | ArmCompileSpec + ) -> None: + self.compile_spec: ArmCompileSpec + if isinstance(compile_spec_or_tosa_spec, TosaSpecification): + from executorch.backends.arm.tosa.compile_spec import TosaCompileSpec + + self.compile_spec = TosaCompileSpec(compile_spec_or_tosa_spec) + self.tosa_spec = self.compile_spec.tosa_spec + elif isinstance(compile_spec_or_tosa_spec, ArmCompileSpec): + self.compile_spec = compile_spec_or_tosa_spec + self.tosa_spec = self.compile_spec.tosa_spec + else: + raise TypeError( + f"TOSAQuantizer constructor expects " + f"a TosaSpecification or compile_spec list, " + f"got {type(compile_spec_or_tosa_spec)}" + ) + + self.pattern_matcher = PatternMatcher(TOSA_QUANTIZER_SUPPORT_DICT) + self.shared_qspec_quantizer = SharedQspecQuantizer() + self.global_quantizer: Quantizer | None = None + self.global_config: Optional[QuantizationConfig] = None + self._quantizers: List[Quantizer] = [] + self._graph_annotations: dict[Node, QuantizationAnnotation] = {} + + @property + def quantizers(self) -> List[Quantizer]: + """Returns the configured quantizers in order of precedence, ensuring + the global config and shared_qspec_quantizer are applied last. + + The returned list is a shallow copy; quantizer instances are shared. - Args: - model (GraphModule): GraphModule being validated. - Raises: - ValueError: If tensor inputs for any operator span more than one - device. """ + quantizers = self._quantizers.copy() + if self.global_quantizer is not None: + quantizers.append(self.global_quantizer) + quantizers.append(self.shared_qspec_quantizer) + + return quantizers + + @quantizers.setter + def quantizers(self, value: List[Quantizer]) -> None: + """Override of quantizers setter to allow for dynamic updating of + quantizers without accessing self._quantizers. + """ + self._quantizers = value + + def annotate(self, model): + reporter = QuantizerReporter(self.quantizers, "FINAL QUANTIZATION REPORT") + model = super().annotate(model) + reporter.log_quantizer_report(model) + return model + + def _remove_annotations(self, model: GraphModule) -> GraphModule: for node in model.graph.nodes: - if node.op != "call_function": - continue + if Q_ANNOTATION_KEY in node.meta: + del node.meta[Q_ANNOTATION_KEY] + if ArmAnnotationInfo.CUSTOM_META_KEY in node.meta: + del node.meta[ArmAnnotationInfo.CUSTOM_META_KEY] + if DISALLOW_TFA_META_KEY in node.meta: + del node.meta[DISALLOW_TFA_META_KEY] + if PatternMatcher.Q_PATTERN_MATCHED_KEY in node.meta: + del node.meta[PatternMatcher.Q_PATTERN_MATCHED_KEY] + + # Clear quantizer internal annotation tracking + self._graph_annotations.clear() - devices = set() - for arg_node in node.all_input_nodes: - meta_val = arg_node.meta.get("val", None) - if meta_val is None: - continue - if isinstance(meta_val, (tuple, list)): - for tensor in meta_val: - devices.add( - str( - getattr( - tensor, - "device", - f"Could not get device from {tensor}", - ) - ) - ) - else: - devices.add( - str( - getattr( - meta_val, - "device", - f"Could not get device from {meta_val}", - ) - ) - ) + return model - if len(devices) > 1: - raise ValueError( - f"Quantizer detected operator {node.name} with different device inputs: {devices}." - ) + def transform_for_annotation(self, model: GraphModule) -> GraphModule: + # Transform_for_annotation should only decompose ops if quantized, which is + # indicated either by node.meta['DISALLOW_TFA_META_KEY']==False or no such key + # existing in the dict. This means that ops are assumed to be quantized by + # default and we need to explicitly annotate all non-quantized nodes with + # DISALLOW_TFA_META_KEY=True before calling the pass manager. + + # For _TOSAQuantizerV2 there is no simple filter which directly finds unquantized + # nodes since nodes can be annotated by any quantizer. Instead, self.annotate is + # run to set DISALLOW_TFA_META_KEY for quantized nodes and all nodes missing + # this key afterwards are set to DISALLOW_TFA_META_KEY=True. + reporter = QuantizerReporter( + self.quantizers, "PRE-TRANSFORM_FOR_ANNOTATION QUANTIZATION REPORT" # type: ignore[arg-type] + ) + model = super().annotate(model) + reporter.log_quantizer_report(model) + for node in model.graph.nodes: + if DISALLOW_TFA_META_KEY not in node.meta: + node.meta[DISALLOW_TFA_META_KEY] = True - def quantize_with_submodules( - self, - model: GraphModule, - calibration_samples: list[tuple], - is_qat: bool = False, - fold_quantize: bool = True, - ): - """Quantizes a GraphModule in a way such that conditional submodules are - handled properly. + # TODO: Fix the need to lazily import this. + from executorch.backends.arm._passes import ArmPassManager - Note: torchao's prepare_pt2e and convert_pt2e natively handle - while_loop body_fn submodules, so we only manually process cond - branches and while_loop cond_fn here. + pass_manager = ArmPassManager(self.compile_spec) + transformed_model = pass_manager.transform_for_annotation_pipeline(model) - Args: - model (GraphModule): The model to quantize. - calibration_samples (list[tuple]): A list of inputs to used to - calibrate the model during quantization. To properly calibrate a - model with submodules, at least one sample per code path is - needed. - is_qat (bool): Whether to do quantization aware training or not. - fold_quantize (bool): Enables or disables constant folding when quantization - is completed. + # Remove the temporary annotations + return self._remove_annotations(transformed_model) - Returns: - GraphModule: The quantized model. + def add_quantizer(self, quantizer: Quantizer) -> _TOSAQuantizerV2: + """Insert a quantizer with highest precedence.""" + self._quantizers.insert(0, quantizer) + return self + + def set_node_finder( + self, quantization_config: Optional[QuantizationConfig], node_finder: NodeFinder + ) -> _TOSAQuantizerV2: + """Add a quantizer targeting nodes found by the provided finder. + + ``None`` indicates no quantization for matched nodes. """ - prepare_fn = prepare_qat_pt2e if is_qat else prepare_pt2e + quantizer = PatternQuantizer( + quantization_config, node_finder, self.pattern_matcher + ) + self.add_quantizer(quantizer) + return self - prepared = prepare_fn(model, self) - # Prepare conditional submodules (e.g., if/while bodies) - # prepare only cond branches and while_loop cond_fn - for name, submodule, _ in get_cond_while_submodules_nested( - prepared, apply_quantization=True - ): - prepared.set_submodule(name, prepare_fn(submodule, self), strict=True) - for submodule_node in submodule.graph.nodes: - if is_submodule_node(submodule_node): - for nested_name, nested_sub, _ in get_cond_while_submodules_nested( - submodule, apply_quantization=True - ): - prepared.set_submodule( - nested_name, prepare_fn(nested_sub, self), strict=True - ) + def set_global( + self, quantization_config: Optional[QuantizationConfig] + ) -> _TOSAQuantizerV2: + """Set the default quantization config for all nodes. - for inp in calibration_samples: - prepared(*inp) + ``None`` indicates no quantization. - # Prepare conditional submodules (e.g., if/while bodies) - # convert only cond branches and while_loop cond_fn - for _, submodule, _ in get_cond_while_submodules_nested( - prepared, apply_quantization=True - ): - converted = convert_pt2e(submodule) - for submodule_node in submodule.graph.nodes: - if is_submodule_node(submodule_node): - for nested_name, nested_sub, _ in get_cond_while_submodules_nested( - submodule, apply_quantization=True - ): - converted.set_submodule( - nested_name, convert_pt2e(nested_sub), strict=True - ) + """ + node_finder = GlobalNodeFinder() + self.global_quantizer = PatternQuantizer( + quantization_config, node_finder, self.pattern_matcher + ) + self.global_config = quantization_config + return self - return convert_pt2e(prepared) + def set_node_target( + self, node_target: OpOverload, quantization_config: Optional[QuantizationConfig] + ) -> _TOSAQuantizerV2: + """Set quantization config for a specific operator target.""" + node_finder = NodeTargetNodeFinder(node_target) + self.set_node_finder(quantization_config, node_finder) + return self + + def set_node_name( + self, node_name: str, quantization_config: Optional[QuantizationConfig] + ) -> _TOSAQuantizerV2: + """Set quantization config for a specific node name.""" + node_finder = NodeNameNodeFinder(node_name) + self.set_node_finder(quantization_config, node_finder) + return self + + def set_module_type( + self, module_type: Callable, quantization_config: Optional[QuantizationConfig] + ) -> _TOSAQuantizerV2: + """Set quantization config for nodes originating from a module type.""" + node_finder = ModuleTypeNodeFinder(module_type) + self.set_node_finder(quantization_config, node_finder) + return self + + def set_module_name( + self, module_name: str, quantization_config: Optional[QuantizationConfig] + ) -> _TOSAQuantizerV2: + """Set quantization config for nodes originating from a module name.""" + node_finder = ModuleNameNodeFinder(module_name) + self.set_node_finder(quantization_config, node_finder) + return self + + def set_io( + self, quantization_config: Optional[QuantizationConfig] + ) -> _TOSAQuantizerV2: + """Set quantization_config for input and output nodes. + + Args: + quantization_config (Optional[QuantizationConfig]): Configuration + describing activation quantization for model inputs and outputs. + ``None`` indicates no quantization. + + """ + input_finder = InputNodeFinder() + output_finder = OutputNodeFinder() + self.set_node_finder(quantization_config, input_finder) + self.set_node_finder(quantization_config, output_finder) + return self class EthosUQuantizer(TOSAQuantizer): @@ -715,11 +1106,16 @@ class EthosUQuantizer(TOSAQuantizer): Args: compile_spec (EthosUCompileSpec): Backend compile specification for Ethos-U targets. + use_composable_quantizer (bool): Whether to use the composable quantizer implementation. See https://github.com/pytorch/executorch/issues/17701" for details. """ - def __init__(self, compile_spec: EthosUCompileSpec) -> None: - super().__init__(compile_spec) + def __init__( + self, + compile_spec: EthosUCompileSpec, + use_composable_quantizer: bool = False, + ) -> None: + super().__init__(compile_spec, use_composable_quantizer) class VgfQuantizer(TOSAQuantizer): @@ -728,8 +1124,13 @@ class VgfQuantizer(TOSAQuantizer): Args: compile_spec (VgfCompileSpec): Backend compile specification for Vgf targets. + use_composable_quantizer (bool): Whether to use the composable quantizer implementation. See https://github.com/pytorch/executorch/issues/17701" for details. """ - def __init__(self, compile_spec: VgfCompileSpec) -> None: - super().__init__(compile_spec) + def __init__( + self, + compile_spec: VgfCompileSpec, + use_composable_quantizer: bool = False, + ) -> None: + super().__init__(compile_spec, use_composable_quantizer) diff --git a/backends/arm/quantizer/quantization_config.py b/backends/arm/quantizer/quantization_config.py index 7e201644262..e6c53ebf966 100644 --- a/backends/arm/quantizer/quantization_config.py +++ b/backends/arm/quantizer/quantization_config.py @@ -13,13 +13,17 @@ from dataclasses import dataclass +from typing import Any, Callable, cast, Optional import torch +from torch.fx import Node from torchao.quantization.pt2e import ObserverOrFakeQuantize from torchao.quantization.pt2e.quantizer import ( DerivedQuantizationSpec, QuantizationSpec, + QuantizationSpecBase, + SharedQuantizationSpec, ) @@ -31,27 +35,29 @@ class QuantizationConfig: expose validated accessors. Attributes: - input_activation (QuantizationSpec | None): Spec for input activations. - output_activation (QuantizationSpec | None): Spec for output activations. - weight (QuantizationSpec | None): Spec for weights. - bias (QuantizationSpec | None): Spec for bias values. + input_activation (Optional[QuantizationSpec]): Spec for input activations. + output_activation (Optional[QuantizationSpec]): Spec for output activations. + weight (Optional[QuantizationSpec]): Spec for weights. + bias (Optional[QuantizationSpec]): Spec for bias values. """ - input_activation: QuantizationSpec | None - output_activation: QuantizationSpec | None - weight: QuantizationSpec | None - bias: QuantizationSpec | None + input_activation: Optional[QuantizationSpecBase] + output_activation: Optional[QuantizationSpecBase] + weight: Optional[QuantizationSpecBase] + bias: Optional[QuantizationSpecBase] | Callable[[Any], Any] - def get_input_act_qspec(self) -> QuantizationSpec | None: + def get_input_act_qspec( + self, node: Optional[Node] = None, input_node: Optional[Node] = None + ) -> Optional[QuantizationSpecBase]: """Get the validated input activation spec. Validate that the input activation qscheme is supported before returning the spec. Returns: - QuantizationSpec | None: Input activation spec, or ``None`` when - unset. + Optional[QuantizationSpecBase]: Input activation spec, or ``None`` when + unset. The ``node`` and ``input_node`` arguments are used by subclasses. Raises: ValueError: If the qscheme is not per-tensor affine or symmetric. @@ -60,7 +66,9 @@ def get_input_act_qspec(self) -> QuantizationSpec | None: if self.input_activation is None: return None # Validate that input_activation uses a supported qscheme - if self.input_activation.qscheme not in [ + if not hasattr( + self.input_activation, "qscheme" + ) or self.input_activation.qscheme not in [ torch.per_tensor_affine, torch.per_tensor_symmetric, ]: @@ -69,15 +77,18 @@ def get_input_act_qspec(self) -> QuantizationSpec | None: ) return self.input_activation - def get_output_act_qspec(self) -> QuantizationSpec | None: + def get_output_act_qspec( + self, node: Optional[Node] = None + ) -> Optional[QuantizationSpecBase]: """Get the validated output activation spec. Validate that the output activation qscheme is supported before returning the spec. Returns: - QuantizationSpec | None: Output activation spec, or ``None`` when - unset. + Optional[QuantizationSpecBase]: Output activation spec, or ``None`` when + unset. The ``node`` argument is currently unused and kept for + API parity. Raises: ValueError: If the qscheme is not per-tensor affine or symmetric. @@ -86,7 +97,9 @@ def get_output_act_qspec(self) -> QuantizationSpec | None: if self.output_activation is None: return None # Validate that output_activation uses a supported qscheme - if self.output_activation.qscheme not in [ + if not hasattr( + self.output_activation, "qscheme" + ) or self.output_activation.qscheme not in [ torch.per_tensor_affine, torch.per_tensor_symmetric, ]: @@ -95,14 +108,16 @@ def get_output_act_qspec(self) -> QuantizationSpec | None: ) return self.output_activation - def get_weight_qspec(self) -> QuantizationSpec | None: + def get_weight_qspec( + self, node: Optional[Node] = None + ) -> Optional[QuantizationSpecBase]: """Get the validated weight spec. Validate that the weight qscheme is supported (per-tensor or per-channel symmetric) before returning the spec. Returns: - QuantizationSpec | None: Weight spec, or ``None`` when unset. + Optional[QuantizationSpecBase]: Weight spec, or ``None`` when unset. Raises: ValueError: If the qscheme is not a supported symmetric scheme. @@ -111,25 +126,27 @@ def get_weight_qspec(self) -> QuantizationSpec | None: if self.weight is None: return None # Validate that weight uses a supported qscheme - if self.weight.qscheme not in [ + if not hasattr(self.weight, "qscheme") or self.weight.qscheme not in [ torch.per_tensor_symmetric, torch.per_channel_symmetric, ]: raise ValueError(f"Unsupported quantization_spec {self.weight} for weight") return self.weight - def get_bias_qspec(self, node: torch.fx.Node) -> QuantizationSpec | None: + def get_bias_qspec( + self, node: Optional[Node] = None + ) -> Optional[QuantizationSpecBase] | Callable[[Any], Any]: """Get the derived or validated bias spec. For conv/linear ops, derive bias qparams from the input/weight observers. Otherwise, validate a user-provided floating-point bias spec. Args: - node (torch.fx.Node): Node whose bias spec is requested. + node (Optional[Node]): Node whose bias spec is requested. Returns: - QuantizationSpec | None: Derived or provided bias spec, or ``None`` - when unset. + Optional[QuantizationSpecBase]: Derived or provided bias spec, or + ``None`` when unset. Raises: ValueError: If deriving qparams sees an unexpected number of @@ -138,6 +155,9 @@ def get_bias_qspec(self, node: torch.fx.Node) -> QuantizationSpec | None: """ + if self.bias is None or node is None: + return None + def _derive_qparams_fn( obs_or_fqs: list[ObserverOrFakeQuantize], ) -> tuple[torch.Tensor, torch.Tensor]: @@ -185,6 +205,12 @@ def _derive_qparams_fn( raise ValueError( "Input activation and weight QuantizationConfig must be specified." ) + if not isinstance( + self.input_activation, QuantizationSpec + ) or not isinstance(self.weight, QuantizationSpec): + raise ValueError( + "QuantizationConfig input_activation and weight must be instances of QuantizationSpec." + ) if (self.input_activation.dtype == self.weight.dtype == torch.int8) or ( self.input_activation.dtype == torch.int16 @@ -211,7 +237,7 @@ def _derive_qparams_fn( ch_axis = 0 quantization_spec = DerivedQuantizationSpec( - derived_from=[(input_act, node), (weight, node)], # type: ignore[list-item] + derived_from=((input_act, node), (weight, node)), # type: ignore[arg-type] derive_qparams_fn=_derive_qparams_fn, dtype=torch.int32, quant_min=torch.iinfo(torch.int32).min + 1, @@ -225,11 +251,95 @@ def _derive_qparams_fn( f"Bias quantization of types: i:{self.input_activation.dtype}, w:{self.weight.dtype} not implemented" ) - if self.bias is None: - return None - # Validate that bias dtype is floating-point - if self.bias.dtype != torch.float: - raise ValueError( - "Only float dtype for bias is supported for bias right now" - ) return self.bias + + +class TOSAQuantizationConfig(QuantizationConfig): + """Configures quantization, while enforcing TOSA specific constraints.""" + + SHARED_OUTPUT_ACT_QSPEC_PATTERNS = { + torch.ops.aten.adaptive_avg_pool2d.default, + torch.ops.aten.upsample_bilinear2d.vec, + torch.ops.aten.upsample_nearest2d.vec, + torch.ops.aten.avg_pool2d.default, + torch.ops.aten.max_pool2d.default, + torch.ops.aten.mean.default, + torch.ops.aten.mean.dim, + torch.ops.aten.silu.default, + torch.ops.aten.silu_.default, + } + + SHARED_INPUT_ACT_QSPEC_PATTERNS = { + torch.ops.aten.lt.Tensor, + torch.ops.aten.le.Tensor, + torch.ops.aten.gt.Tensor, + torch.ops.aten.ge.Tensor, + torch.ops.aten.eq.Tensor, + torch.ops.aten.ne.Tensor, + } + + def get_input_act_qspec(self, node=None, input_node=None): + """Return the configured input quantization spec. + + For comparison operators, make sure that both inputs share the same + quantization spec, by returning a SharedQuantizationSpec that ties the + quantization of both inputs together. For other operators, return the + default input activation spec. + + """ + if node is None or input_node is None: + return super().get_input_act_qspec(node, input_node) + + if node.target in self.SHARED_INPUT_ACT_QSPEC_PATTERNS: + if input_node == node.args[0]: + return super().get_input_act_qspec(node, input_node) + else: + return SharedQuantizationSpec((node.args[0], node)) + + return super().get_input_act_qspec(node, input_node) + + def get_weight_qspec( + self, node: Optional[Node] = None + ) -> Optional[QuantizationSpecBase]: + """Return the configured weight quantization spec. + + For conv transpose, return the per-channel quantization spec with + `ch_axis=1` to match the IOHW weight format used by TOSA, instead of + the default `ch_axis=0`. If no weight spec is configured, return + ``None``. + + """ + weight_qspec = super().get_weight_qspec() + if ( + node is not None + and weight_qspec is not None + and isinstance(weight_qspec, QuantizationSpec) + and weight_qspec.qscheme == torch.per_channel_symmetric + and node.target == torch.ops.aten.conv_transpose2d.input + ): + # MLETORCH-1853: Fix lazy import when moving files around + from executorch.backends.arm.quantizer.quantization_annotator import ( + _adjust_weight_qspec_for_conv_transpose, + ) + + weight_qspec = _adjust_weight_qspec_for_conv_transpose(node, weight_qspec) + + return weight_qspec + + def get_output_act_qspec( + self, node: Optional[Node] = None + ) -> Optional[QuantizationSpecBase]: + """Return the configured output activation quantization spec. + + If node is a pooling or upsample operator, returns a shared quantization spec. + If no weight spec is configured, return ``None``. + + """ + + if node is None: + return super().get_output_act_qspec() + if node.target not in self.SHARED_OUTPUT_ACT_QSPEC_PATTERNS: + return super().get_output_act_qspec() + if len(node.args) == 0: + return super().get_output_act_qspec() + return SharedQuantizationSpec((cast(Node, node.args[0]), node)) diff --git a/backends/arm/quantizer/quantizer_support.py b/backends/arm/quantizer/quantizer_support.py new file mode 100644 index 00000000000..3558cfac792 --- /dev/null +++ b/backends/arm/quantizer/quantizer_support.py @@ -0,0 +1,180 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from itertools import product + +import torch +from executorch.backends.arm._passes.arm_pass_utils import get_first_fake_tensor +from executorch.backends.arm.quantizer.quantization_annotator import ( + _conv_ops, + _one_to_one, +) +from executorch.backends.cortex_m.quantizer.pattern_checkers import PatternCheck +from torch._ops import OpOverload + + +def combo_pattern(*pattern_lists): + "Returns the cartesian product of the given pattern lists." + return [tuple(pattern) for pattern in product(*pattern_lists)] + + +class ReluFusedPatternCheck(PatternCheck): + @classmethod + def check_quantization_config(cls, pattern, quantization_config): + if quantization_config is None: + return True + + output_node = pattern[-1] if pattern else None + output_qspec = quantization_config.get_output_act_qspec(output_node) + if output_qspec is None: + return False + + return output_qspec.qscheme not in ( + torch.per_tensor_symmetric, + torch.per_channel_symmetric, + ) + + +class ArithmeticFloatInputsCheck(PatternCheck): + @classmethod + def check_pattern(cls, pattern): + """For arithmetic ops all inputs must be quantizeable for quantization + to make sense. + """ + for node in pattern: + for input_node in node.all_input_nodes: + try: + tensor = get_first_fake_tensor(input_node) + except Exception: + return False + if not tensor.dtype.is_floating_point: + return False + + return True + + +BINARY_OP_PATTERNS = [ + (torch.ops.aten.add.Tensor,), + (torch.ops.aten.add_.Tensor,), + (torch.ops.aten.sub.Tensor,), + (torch.ops.aten.sub_.Tensor,), + (torch.ops.aten.matmul.default,), + (torch.ops.aten.mm.default,), + (torch.ops.aten.bmm.default,), + (torch.ops.aten.mul.Tensor,), + (torch.ops.aten.mul_.Tensor,), +] +ACTIVATION_FUNCTION_PATTERNS = [ + (torch.ops.aten.hardswish.default,), + (torch.ops.aten.hardswish_.default,), +] + +LINEAR_OPS = [torch.ops.aten.linear.default] +FUSED_ACTIVATION_OPS = [ + torch.ops.aten.relu.default, + torch.ops.aten.relu_.default, + torch.ops.aten.hardtanh.default, + torch.ops.aten.hardtanh_.default, + torch.ops.aten.hardsigmoid.default, + torch.ops.aten.hardsigmoid_.default, + torch.ops.aten.clamp.default, + torch.ops.aten.clamp_.default, +] +BATCH_NORM_OPS = [torch.ops.aten.batch_norm.default] +LINEAR_OP_PATTERNS = ( + combo_pattern(LINEAR_OPS) + + combo_pattern(LINEAR_OPS, FUSED_ACTIVATION_OPS) + + combo_pattern(LINEAR_OPS, BATCH_NORM_OPS) + + combo_pattern(LINEAR_OPS, BATCH_NORM_OPS, FUSED_ACTIVATION_OPS) +) +CONV_OP_PATTERNS = ( + combo_pattern(_conv_ops) + + combo_pattern(_conv_ops, FUSED_ACTIVATION_OPS) + + combo_pattern(_conv_ops, BATCH_NORM_OPS) + + combo_pattern(_conv_ops, BATCH_NORM_OPS, FUSED_ACTIVATION_OPS) +) +FUSED_RELU_OP_PATTERNS = ( + combo_pattern(LINEAR_OPS, FUSED_ACTIVATION_OPS) + + combo_pattern(LINEAR_OPS, BATCH_NORM_OPS, FUSED_ACTIVATION_OPS) + + combo_pattern(_conv_ops, FUSED_ACTIVATION_OPS) + + combo_pattern(_conv_ops, BATCH_NORM_OPS, FUSED_ACTIVATION_OPS) +) + +ALL_QPARAM_OP_PATTERNS = ( + [(target,) for target in _one_to_one] + + ACTIVATION_FUNCTION_PATTERNS + + CONV_OP_PATTERNS + + LINEAR_OP_PATTERNS + + BINARY_OP_PATTERNS + + [ + (torch.ops.aten.full.default,), + (torch.ops.aten.full,), + (torch.ops.aten.zeros.default,), + (torch.ops.aten.ones.default,), + (torch.ops.aten.fill_.Scalar,), + (torch.ops.aten.scalar_tensor.default,), + (torch.ops.aten.zeros_like.default,), + (torch.ops.aten._softmax.default,), + (torch.ops.aten.softmax.int,), + (torch.ops.aten.div.Tensor,), + (torch.ops.aten.div_.Tensor,), + (torch.ops.aten.div.Tensor_mode,), + (torch.ops.aten.floor,), + (torch.ops.aten.floor_divide.default,), + (torch.ops.aten.logit.default,), + (torch.ops.aten.glu.default,), + (torch.ops.aten.addmm.default,), + (torch.ops.aten.layer_norm.default,), + (torch.ops.aten.group_norm.default,), + (torch.ops.aten.sqrt.default,), + (torch.ops.aten.silu.default,), + (torch.ops.aten.silu_.default,), + (torch.ops.aten.logit.default,), + (torch.ops.aten.var.dim,), + (torch.ops.aten.var.correction,), + (torch.ops.aten.leaky_relu.default,), + (torch.ops.aten.leaky_relu_.default,), + (torch.ops.aten.linalg_vector_norm.default,), + (torch.ops.aten.log_softmax.int,), + (torch.ops.aten.round.default,), + (torch.ops.aten.arange.start_step,), + (torch.ops.aten.embedding.default,), + (torch.ops.aten.adaptive_avg_pool2d.default,), + (torch.ops.aten.upsample_bilinear2d.vec,), + (torch.ops.aten.upsample_nearest2d.vec,), + (torch.ops.aten.avg_pool2d.default,), + (torch.ops.aten.max_pool2d.default,), + (torch.ops.aten.cosine_similarity.default,), + (torch.ops.aten.sigmoid.default,), + (torch.ops.aten.remainder.Tensor,), + (torch.ops.aten.remainder.Scalar,), + (torch.ops.aten.mean.dim,), + (torch.ops.aten.mean.default,), + (torch.ops.aten.neg.default,), + (torch.ops.aten.scaled_dot_product_attention.default,), + (torch.ops.aten.abs.default,), + (torch.ops.aten.minimum.default,), + (torch.ops.aten.maximum.default,), + (torch.ops.aten.lt.Tensor,), + (torch.ops.aten.le.Tensor,), + (torch.ops.aten.gt.Tensor,), + (torch.ops.aten.ge.Tensor,), + (torch.ops.aten.eq.Tensor,), + (torch.ops.aten.ne.Tensor,), + (torch.ops.aten.lt.Scalar,), + (torch.ops.aten.le.Scalar,), + (torch.ops.aten.gt.Scalar,), + (torch.ops.aten.ge.Scalar,), + (torch.ops.aten.eq.Scalar,), + (torch.ops.aten.ne.Scalar,), + ] +) +TOSA_QUANTIZER_SUPPORT_DICT: dict[tuple[OpOverload, ...], type[PatternCheck] | None] = { + pattern: None for pattern in ALL_QPARAM_OP_PATTERNS +} +for pattern in FUSED_RELU_OP_PATTERNS: + TOSA_QUANTIZER_SUPPORT_DICT[pattern] = ReluFusedPatternCheck +for pattern in BINARY_OP_PATTERNS: + TOSA_QUANTIZER_SUPPORT_DICT[pattern] = ArithmeticFloatInputsCheck diff --git a/backends/arm/test/quantizer/test_set_module_name.py b/backends/arm/test/quantizer/test_set_module_name.py index d0ca781256f..6ca7e3f970c 100644 --- a/backends/arm/test/quantizer/test_set_module_name.py +++ b/backends/arm/test/quantizer/test_set_module_name.py @@ -1,4 +1,4 @@ -# Copyright 2025 Arm Limited and/or its affiliates. +# Copyright 2025-2026 Arm Limited and/or its affiliates. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -113,13 +113,13 @@ def validate_node( if len(node.all_input_nodes) == 3: input_node, weight_node, bias_node = node.all_input_nodes bias_qspec = quantization_config.get_bias_qspec(node) - validate_input(bias_node, bias_qspec) + validate_input(bias_node, bias_qspec) # type: ignore[arg-type] else: input_node, weight_node = node.all_input_nodes - validate_input(input_node, input_qspec) - validate_input(weight_node, weight_qspec) - validate_output(node, output_qspec) + validate_input(input_node, input_qspec) # type: ignore[arg-type] + validate_input(weight_node, weight_qspec) # type: ignore[arg-type] + validate_output(node, output_qspec) # type: ignore[arg-type] def test_set_module_name_tosa_INT() -> None: diff --git a/backends/cortex_m/quantizer/pattern_matcher.py b/backends/cortex_m/quantizer/pattern_matcher.py index 7ed848d5e83..3c391546b85 100644 --- a/backends/cortex_m/quantizer/pattern_matcher.py +++ b/backends/cortex_m/quantizer/pattern_matcher.py @@ -45,7 +45,7 @@ class PatternMatcher: def __init__( self, - support_dict: dict[tuple[OpOverload, ...], PatternCheck], + support_dict: dict[tuple[OpOverload, ...], Optional[type[PatternCheck]]], support_dict_name: str | None = None, ): self.support_dict = support_dict diff --git a/backends/cortex_m/quantizer/quantization_configs.py b/backends/cortex_m/quantizer/quantization_configs.py index 030c0fa4d93..fd25db86fbb 100644 --- a/backends/cortex_m/quantizer/quantization_configs.py +++ b/backends/cortex_m/quantizer/quantization_configs.py @@ -86,7 +86,9 @@ class CortexMQuantizationConfig(QuantizationConfig): """Configures quantization, while enforcing cortex-m specific constraints.""" - def get_input_act_qspec(self, node: Node | None = None) -> QuantizationSpec | None: + def get_input_act_qspec( + self, node: Node | None = None, input_node: Node | None = None + ) -> QuantizationSpec | None: """ Returns the configured input activation spec, no specific adjustments. """ diff --git a/backends/cortex_m/quantizer/quantizer.py b/backends/cortex_m/quantizer/quantizer.py index 1e8aa1da47d..14f29dbc9a8 100644 --- a/backends/cortex_m/quantizer/quantizer.py +++ b/backends/cortex_m/quantizer/quantizer.py @@ -6,11 +6,13 @@ import logging import operator -from typing import List, Optional +from typing import Callable, List, Optional import torch from executorch.backends.arm.common.annotation_meta import ArmAnnotationInfo +from executorch.backends.arm.constants import DISALLOW_TFA_META_KEY + from executorch.backends.arm.quantizer.quantization_config import QuantizationConfig from executorch.backends.cortex_m.passes.cortex_m_pass_manager import CortexMPassManager from executorch.backends.cortex_m.quantizer.node_finders import ( @@ -36,7 +38,6 @@ CONV_TRANSPOSE_OP_PATTERNS, CORTEX_M_QUANTIZER_SUPPORT_DICT, ) -from torch._ops import OpOverload from torch.fx import GraphModule, Node from torchao.quantization.pt2e.quantizer import ( ComposableQuantizer, @@ -59,18 +60,34 @@ def has_float_output(node: Node) -> bool: def mark_node_as_annotated( node: Node, - input_qspec_map: dict[Node, Optional[QuantizationSpec]], - output_qspec: Optional[QuantizationSpec], - reporter: Optional[QuantizerReporter] = None, - quantizer: Optional[Quantizer] = None, + input_qspec_map, + output_qspec, + is_quantized, ) -> None: - node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation(input_qspec_map, output_qspec) - annotation_info = ArmAnnotationInfo( - quantized=True, + """Fills various meta data fields required for quantization, partitioning, + and backend lowering. + + Note: quantization_config is needed to distinguish between otherwise + identical annotations: + 1. Node explicitly marked as not quantized using quantization_config = None + 2. Node which is quantized but all inputs/outputs are quantized + """ + + # Annotate node for toracho + node.meta[Q_ANNOTATION_KEY] = QuantizationAnnotation( + input_qspec_map, output_qspec, _annotated=True ) - meta_custom = node.meta.get("custom", {}) - meta_custom[ArmAnnotationInfo.CUSTOM_META_KEY] = dict(annotation_info) - node.meta["custom"] = meta_custom + + # Mark operator nodes as quantized to be folded in fold_qdq_with_annotated_qparams and know what to partition + if node.op == "call_function": + meta_custom = node.meta.get("custom", {}) + meta_custom[ArmAnnotationInfo.CUSTOM_META_KEY] = ArmAnnotationInfo( + quantized=is_quantized + ) + node.meta["custom"] = meta_custom + + # Mark nodes to not be touched by transform_for_annotation in quantization dry-run + node.meta[DISALLOW_TFA_META_KEY] = not is_quantized class CortexMQuantizer(ComposableQuantizer): @@ -128,12 +145,12 @@ class PatternQuantizer(Quantizer, QuantizerReporterUser): def __init__( self, - quantization_config: QuantizationConfig, + quantization_config: QuantizationConfig | None, node_finder: NodeFinder, pattern_matcher: PatternMatcher, ) -> None: super().__init__() - self.quantization_config: QuantizationConfig = quantization_config + self.quantization_config: QuantizationConfig | None = quantization_config self.node_finder: NodeFinder = node_finder self.pattern_matcher: PatternMatcher = pattern_matcher @@ -141,7 +158,7 @@ def get_quantizer_info(self): name = self.__class__.__name__ targeted_nodes_description = str(self.node_finder) quantization_config_path = SUPPORTED_QCONFIGS.get( - self.quantization_config, "CUSTOM_QCONFIG" + self.quantization_config, "UNREGISTRED_QCONFIG" ) support_config_path = self.pattern_matcher.support_dict_name @@ -182,15 +199,34 @@ def annotate_match( - All other outputs coming out of the matched pattern are annotated as output activations. """ + parameter_targets = { + torch.ops.aten.linear.default, + torch.ops.aten.convolution.default, + torch.ops.aten.conv1d.default, + torch.ops.aten.conv1d.padding, + torch.ops.aten.conv2d.default, + torch.ops.aten.conv2d.padding, + torch.ops.aten.conv3d.default, + torch.ops.aten.conv3d.padding, + torch.ops.aten.conv_transpose2d.input, + } + for node in match: input_qspec_map = {} output_qspec = None params = [n for n in node.all_input_nodes if self.is_parameter(n, model)] # Check that the assumptions on number of parameters hold to avoid silent errors - assert ( - 0 <= len(params) <= 2 - ), f"{self.__class__.__name__} expected 0 params, 1 params (weight) or 2 params (weight, bias), but got {len(params)} for node {node}." + if node.target in parameter_targets: + if len(params) == 0 or len(params) > 2: + logger.warning( + f"{node.name} is expected to have parameter tensors for weight/bias but no such inputs found, which may cause unexpected quantization annotations. This is likely caused by incorrect tensor instantiations or non-constant weight/biases." + ) + else: + if len(params) > 0: + logger.warning( + f"{node.name} is not expected to not have parameter tensors but found {[n.name for n in params]}, which may cause unexpected quantization annotations." + ) for input_node in node.all_input_nodes: # Observers only work on floating point tensors, so make sure to skip other dtypes @@ -207,14 +243,18 @@ def annotate_match( ) elif input_node not in match: input_qspec_map[input_node] = ( - config.get_input_act_qspec() if config else None + config.get_input_act_qspec(node, input_node) if config else None ) if all(node not in match for node in node.users) and output_qspec is None: - output_qspec = config.get_output_act_qspec(node) if config else None + if has_float_output(node): + output_qspec = config.get_output_act_qspec(node) if config else None mark_node_as_annotated( - node, input_qspec_map, output_qspec, self.reporter, self + node, + input_qspec_map, + output_qspec, + config is not None, # None qconfig -> explicitly not quantized node ) def annotate(self, model: GraphModule) -> None: @@ -242,11 +282,11 @@ class SharedQspecQuantizer(Quantizer, QuantizerReporterUser): i.e. ops which does not change the scale such as clone, min/max, transposes and so on. Args: - targets (Optional[List[OpOverload]]): List of operator overloads to apply shared quantization spec to. - If None, a default list of supported ops is used. + targets (Optional[List[Callable[..., object]]]): List of operator targets to apply shared + quantization specs to. If None, a default list of supported ops is used. """ - SHARED_QSPEC_OPS_DEFAULT: List[OpOverload] = [ + SHARED_QSPEC_OPS_DEFAULT: List[Callable[..., object]] = [ # Clone torch.ops.aten.clone.default, torch.ops.aten.lift_fresh_copy.default, @@ -344,7 +384,7 @@ class SharedQspecQuantizer(Quantizer, QuantizerReporterUser): torch.ops.higher_order.cond, ] - def __init__(self, targets: Optional[List[OpOverload]] = None) -> None: + def __init__(self, targets: Optional[List[Callable[..., object]]] = None) -> None: super().__init__() if targets is None: self.targets = self.SHARED_QSPEC_OPS_DEFAULT @@ -439,6 +479,7 @@ def _annotate_shared_cluster(self, root_node: Node) -> None: root_node, {}, None, + is_quantized=True, ) return @@ -482,9 +523,7 @@ def _annotate_shared_cluster(self, root_node: Node) -> None: else: output_qspec = shared_qspec mark_node_as_annotated( - node, - input_qspec_map, - output_qspec, + node, input_qspec_map, output_qspec, is_quantized=True ) # Force the root qspec to be the adjacent spec diff --git a/backends/cortex_m/quantizer/quantizer_reporter.py b/backends/cortex_m/quantizer/quantizer_reporter.py index 8c5151a5c7f..2df704ea8a4 100644 --- a/backends/cortex_m/quantizer/quantizer_reporter.py +++ b/backends/cortex_m/quantizer/quantizer_reporter.py @@ -31,31 +31,66 @@ ) from tabulate import tabulate from torch.fx import GraphModule, Node -from torchao.quantization.pt2e.quantizer import Quantizer +from torchao.quantization.pt2e.quantizer import ( + DerivedQuantizationSpec, + QuantizationSpec, + QuantizationSpecBase, + Quantizer, + SharedQuantizationSpec, +) from torchao.quantization.pt2e.quantizer.quantizer import Q_ANNOTATION_KEY logger = logging.getLogger(__name__) # Look-up dicts used to get human readable names for supported quantization configs and specs -SUPPORTED_QCONFIGS = { +SUPPORTED_QCONFIGS: dict[any, str] = { INT8_PER_CHANNEL_CONFIG: f"{quantization_configs_module}.INT8_PER_CHANNEL_QCONFIG", INT8_PER_TENSOR_CONFIG: f"{quantization_configs_module}.INT8_PER_TENSOR_QCONFIG", } -SUPPORTED_QSPECS = { +SUPPORTED_QSPECS: dict[QuantizationSpecBase | None, str] = { INT8_ACTIVATION_PER_TENSOR_QSPEC: "INT8_ACTIVATION_PER_TENSOR_QSPEC", INT8_ACTIVATION_PER_CHANNEL_QSPEC: "INT8_ACTIVATION_PER_CHANNEL_QSPEC", INT8_WEIGHT_PER_TENSOR_QSPEC: "INT8_WEIGHT_PER_TENSOR_QSPEC", INT8_WEIGHT_PER_CHANNEL_QSPEC: "INT8_WEIGHT_PER_CHANNEL_QSPEC", INT8_WEIGHT_PER_CHANNEL_TRANSPOSE_QSPEC: "INT8_WEIGHT_PER_CHANNEL_TRANSPOSE_QSPEC", SOFTMAX_OUTPUT_FIXED_QSPEC: "SOFTMAX_OUTPUT_FIXED_QSPEC", - None: "None", } def _qspec_repr(qspec): - return SUPPORTED_QSPECS.get(qspec, "CUSTOM_QSPEC") + """ + Get a human readable representation of QuantizationSpecs. + + Note that the observer_or_fake_quant_ctr field is created dynamically with the qspec + so two qspecs created at different times will not evaluate as equal. Therefore a + custom comparison is required. + + #TODO: Clean up qconfig/ qspec string representation logic in cortex_m/arm backend. + """ + if isinstance(qspec, SharedQuantizationSpec): + return "SHARED_QSPEC" + elif isinstance(qspec, DerivedQuantizationSpec): + return "DERIVED_QSPEC" + elif qspec is None: + return "NO_QSPEC" + elif isinstance(qspec, QuantizationSpec): + for key, val in SUPPORTED_QSPECS.items(): + if type(qspec) is not type(key): + continue + if qspec.dtype != key.dtype: + continue + if qspec.quant_min != key.quant_min: + continue + if qspec.quant_max != key.quant_max: + continue + if qspec.qscheme != key.qscheme: + continue + if qspec.is_dynamic != key.is_dynamic: + continue + return val + return "UNREGISTRED_QSPEC" class QuantizerInfo(NamedTuple): @@ -283,8 +318,13 @@ class QuantizerReporter: inheriting from QuantizerReporterUser. """ - def __init__(self, quantizers: List[QuantizerReporterUser]): + def __init__( + self, + quantizers: List[QuantizerReporterUser], + report_title: str = "QUANTIZATION REPORT", + ): self.quantizers: Dict[Quantizer, QuantizerReport] = {} + self.report_title = report_title self.set_quantizers(quantizers) def set_quantizers(self, quantizers: List[QuantizerReporterUser]) -> None: @@ -357,7 +397,12 @@ def get_quantization_report( report_rows: List[str] = [] separator = "-" * 100 report_rows.append(separator) - report_rows.append(" " * 39 + " QUANTIZATION REPORT " + " " * 40) + assert ( + len(self.report_title) < 100 + ), "Report title must be less than 100 characters to be properly formatted in the report header." + pre_pad = (100 - len(self.report_title)) // 2 + post_pad = 100 - len(self.report_title) - pre_pad + report_rows.append(" " * pre_pad + f"{self.report_title}" + " " * post_pad) report_rows.append(separator) for report in self.quantizers.values(): diff --git a/backends/cortex_m/test/misc/test_quantizer_reporter.py b/backends/cortex_m/test/misc/test_quantizer_reporter.py index 354a1258f1f..368ff78793c 100644 --- a/backends/cortex_m/test/misc/test_quantizer_reporter.py +++ b/backends/cortex_m/test/misc/test_quantizer_reporter.py @@ -101,11 +101,13 @@ def test_debug_log_level(caplog): add_node, {add_node.args[0]: INT8_WEIGHT_PER_TENSOR_QSPEC, add_node.args[1]: None}, None, + is_quantized=True, ) mark_node_as_annotated( relu_node, {}, INT8_ACTIVATION_PER_CHANNEL_QSPEC, + is_quantized=True, ) quantizer1.report_accept([add_node, relu_node]) quantizer2.report_reject( @@ -128,8 +130,8 @@ def test_debug_log_level(caplog): NODE NAME INPUT QSPEC MAP OUTPUT QSPEC MAP -- ----------- ------------------------------- --------------------------------- - ╒ add x: INT8_WEIGHT_PER_TENSOR_QSPEC None - | y: None + ╒ add x: INT8_WEIGHT_PER_TENSOR_QSPEC NO_QSPEC + | y: NO_QSPEC ╘ relu INT8_ACTIVATION_PER_CHANNEL_QSPEC ---------------------------------------------------------------------------------------------------- DummyQuantizer using dummy nodes