diff --git a/backends/cortex_m/passes/BUCK b/backends/cortex_m/passes/BUCK index 58a705ea3c6..20444f16718 100644 --- a/backends/cortex_m/passes/BUCK +++ b/backends/cortex_m/passes/BUCK @@ -13,7 +13,6 @@ fbcode_target(_kind = runtime.python_library, name="replace_quant_nodes_pass", srcs=[ "replace_quant_nodes_pass.py", - "quantized_op_fusion_pass.py", ], deps=[ "//caffe2:torch", diff --git a/backends/cortex_m/passes/__init__.py b/backends/cortex_m/passes/__init__.py index cd1f2892de2..6d6783488fe 100644 --- a/backends/cortex_m/passes/__init__.py +++ b/backends/cortex_m/passes/__init__.py @@ -41,6 +41,5 @@ def _ensure_cortex_m_dependencies() -> None: from .decompose_hardswish_pass import DecomposeHardswishPass # noqa from .decompose_mean_pass import DecomposeMeanPass # noqa from .quantized_clamp_activation_pass import QuantizedClampActivationPass # noqa -from .quantized_op_fusion_pass import QuantizedOpFusionPass # noqa from .replace_quant_nodes_pass import ReplaceQuantNodesPass # noqa from .cortex_m_pass_manager import CortexMPassManager # noqa # usort: skip diff --git a/backends/cortex_m/passes/aten_to_cortex_m_pass.py b/backends/cortex_m/passes/aten_to_cortex_m_pass.py index e6fe1ec8c21..ecc7187797d 100644 --- a/backends/cortex_m/passes/aten_to_cortex_m_pass.py +++ b/backends/cortex_m/passes/aten_to_cortex_m_pass.py @@ -5,7 +5,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. -from typing import cast +import math +from typing import cast, Optional import cmsis_nn # type: ignore[import-not-found, import-untyped] import executorch.backends.cortex_m.ops.operators # noqa @@ -17,11 +18,17 @@ from executorch.backends.cortex_m.passes.passes_utils import ( build_activation_lut, quantize_multiplier_aot, + quantize_val, + SHIFT_INT8, to_physical_order, ) from executorch.backends.cortex_m.passes.scratch_buffer_sizes import ( required_cmsis_nn_buffer_sizes, ) +from executorch.backends.cortex_m.quantizer.quantization_configs import ( + CMSIS_SOFTMAX_SCALE, + CMSIS_SOFTMAX_ZERO_POINT, +) from executorch.backends.cortex_m.target_config import CortexMTargetConfig from executorch.backends.transforms.aten_to_dialect_pass import ( AtenToDialectPass, @@ -38,6 +45,7 @@ from torch.export import ExportedProgram from torch.export.graph_signature import InputKind from torch.fx import Node +from torch.fx.node import Argument from torch.fx.passes.infra.pass_manager import PassResult @@ -99,6 +107,78 @@ def _create_uninitialized_alloc_node( ) +_SOFTMAX_INPUT_INTEGER_BITS = 5 + + +def _to_int_pair( + value: Argument, default: Optional[tuple[int, int]] +) -> tuple[int, int]: + if value is None: + assert default is not None, "Expected default sequence for normalization" + return (default[0], default[1]) + + try: + int_pair = cast(tuple[int, int], value) + return int_pair + except Exception as exc: + raise ValueError(f"Expected a tuple of two integers, got {value}") from exc + + +def _to_bool(value: Argument, default: bool) -> bool: + if value is None: + return default + try: + bool_value = cast(bool, value) + return bool_value + except Exception as exc: + raise ValueError(f"Expected a boolean value, got {value}") from exc + + +def _is_quant_per_tensor_qualified(node: Node) -> bool: + """Match int8 OR int16 (de)quantize_per_tensor nodes.""" + dtype = node.args[5] + if dtype == torch.int8: + return ( + cast(int, node.args[3]) >= torch.iinfo(torch.int8).min + and cast(int, node.args[4]) <= torch.iinfo(torch.int8).max + ) + if dtype == torch.int16: + return ( + cast(int, node.args[3]) >= torch.iinfo(torch.int16).min + and cast(int, node.args[4]) <= torch.iinfo(torch.int16).max + ) + return False + + +def _compute_softmax_params(input_scale: float) -> tuple[int, int, int]: + """ + Convert per-tensor input scale into fixed-point params for arm_softmax_s8. + """ + real_multiplier = min( + input_scale * (1 << (31 - _SOFTMAX_INPUT_INTEGER_BITS)), + float((1 << 31) - 1), + ) + input_multiplier, input_shift = quantize_multiplier_aot(real_multiplier) + diff_min_term = ( + ((1 << _SOFTMAX_INPUT_INTEGER_BITS) - 1) + * math.ldexp(1.0, 31 - _SOFTMAX_INPUT_INTEGER_BITS) + / math.ldexp(1.0, input_shift) + ) + diff_min = -int(math.floor(diff_min_term)) + return int(input_multiplier), int(input_shift), diff_min + + +def _get_input_tensor_data(node: Node, arg_index: int = 0): + arg = node.args[arg_index] + if isinstance(arg, Node) and "val" in arg.meta: + return get_first_fake_tensor(arg) + if "val" in node.meta: + return get_first_fake_tensor(node) + raise KeyError( + f"Expected fake tensor metadata on input arg {arg_index} or node {node.name}." + ) + + def _compute_kernel_sum(weights, bias, input_offset, weight_offset): """ Computes the precomputed kernel sum term (bias optional) @@ -698,3 +778,302 @@ def _get_avg_pool2d_replacement( return DialectNodeSpec( exir_ops.edge.cortex_m.quantized_avg_pool2d.default, new_args ) + + +@AtenToCortexMPass.register_dialect_substitution( + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default +) +def _get_quantize_per_tensor_replacement( + node: Node, dialect_pass: AtenToDialectPass +) -> DialectNodeSpec | None: + del dialect_pass + if not _is_quant_per_tensor_qualified(node): + return None + return DialectNodeSpec( + exir_ops.edge.cortex_m.quantize_per_tensor.default, node.args + ) + + +@AtenToCortexMPass.register_dialect_substitution( + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default +) +def _get_dequantize_per_tensor_replacement( + node: Node, dialect_pass: AtenToDialectPass +) -> DialectNodeSpec | None: + del dialect_pass + if not _is_quant_per_tensor_qualified(node): + return None + return DialectNodeSpec( + exir_ops.edge.cortex_m.dequantize_per_tensor.default, node.args + ) + + +@AtenToCortexMPass.register_dialect_substitution(exir_ops.edge.aten.add.Tensor) +def _get_add_replacement( + node: Node, dialect_pass: AtenToDialectPass +) -> DialectNodeSpec | None: + del dialect_pass + if not _has_qparams(node): + return None + + scale1 = node.meta["input_qparams"][0].scale + zero_point1 = node.meta["input_qparams"][0].zp + scale2 = node.meta["input_qparams"][1].scale + zero_point2 = node.meta["input_qparams"][1].zp + output_scale = node.meta["output_qparams"][0].scale + output_zero_point = node.meta["output_qparams"][0].zp + + max_scale_2x = 2 * max(scale1, scale2) + input1_mult, input1_shift = quantize_multiplier_aot(scale1 / max_scale_2x) + input2_mult, input2_shift = quantize_multiplier_aot(scale2 / max_scale_2x) + output_mult, output_shift = quantize_multiplier_aot( + max_scale_2x / (output_scale * (1 << SHIFT_INT8)) + ) + + activation_min = node.meta["output_qparams"][0].qmin + activation_max = node.meta["output_qparams"][0].qmax + + args = ( + node.args[0], + zero_point1, + input1_mult, + input1_shift, + node.args[1], + zero_point2, + input2_mult, + input2_shift, + output_zero_point, + output_mult, + output_shift, + activation_min, + activation_max, + ) + return DialectNodeSpec(exir_ops.edge.cortex_m.quantized_add.default, args) + + +@AtenToCortexMPass.register_dialect_substitution(exir_ops.edge.aten.mul.Tensor) +def _get_mul_replacement( + node: Node, dialect_pass: AtenToDialectPass +) -> DialectNodeSpec | None: + del dialect_pass + if not _has_qparams(node): + return None + + scale1 = node.meta["input_qparams"][0].scale + zero_point1 = node.meta["input_qparams"][0].zp + scale2 = node.meta["input_qparams"][1].scale + zero_point2 = node.meta["input_qparams"][1].zp + output_scale = node.meta["output_qparams"][0].scale + output_zero_point = node.meta["output_qparams"][0].zp + + output_mult, output_shift = quantize_multiplier_aot( + (scale1 * scale2) / output_scale + ) + args = ( + node.args[0], + zero_point1, + node.args[1], + zero_point2, + output_zero_point, + output_mult, + output_shift, + ) + return DialectNodeSpec(exir_ops.edge.cortex_m.quantized_mul.default, args) + + +@AtenToCortexMPass.register_dialect_substitution(exir_ops.edge.aten._softmax.default) +def _get_softmax_replacement( + node: Node, dialect_pass: AtenToDialectPass +) -> DialectNodeSpec | None: + del dialect_pass + if not _has_qparams(node): + return None + + half_to_float = node.args[2] if len(node.args) > 2 else False + if cast(bool, half_to_float): + return None + + input_qparams = node.meta["input_qparams"][0] + output_qparams = node.meta["output_qparams"][0] + + input_multiplier, input_shift, diff_min = _compute_softmax_params( + float(input_qparams.scale) + ) + + output_scale_attr = getattr(output_qparams, "scale", None) + output_zp_attr = getattr(output_qparams, "zp", None) + if output_scale_attr is None or output_zp_attr is None: + raise AssertionError("Softmax requires output quantization parameters.") + + output_scale_val = float(output_scale_attr) + output_zp_val = int(output_zp_attr) + if not math.isclose( + output_scale_val, CMSIS_SOFTMAX_SCALE, rel_tol=0.0, abs_tol=1e-12 + ): + raise AssertionError( + "Softmax output scale must match CMSIS (1/256). " f"Got {output_scale_val}." + ) + if output_zp_val != CMSIS_SOFTMAX_ZERO_POINT: + raise AssertionError( + "Softmax output zero-point must match CMSIS (-128). " + f"Got {output_zp_val}." + ) + + args = ( + node.args[0], + node.args[1], + int(input_qparams.zp), + output_zp_val, + input_multiplier, + input_shift, + diff_min, + ) + return DialectNodeSpec(exir_ops.edge.cortex_m.softmax.default, args) + + +@AtenToCortexMPass.register_dialect_substitution(exir_ops.edge.aten.max_pool2d.default) +def _get_max_pool2d_replacement( + node: Node, dialect_pass: AtenToDialectPass +) -> DialectNodeSpec | None: + del dialect_pass + input_qparams = node.meta.get("input_qparams", {}).get(0) + cortex_m_meta = node.meta.get("custom", {}).get("cortex_m", {}) + if input_qparams is None or cortex_m_meta.get("skip_quantized_max_pool2d", False): + return None + + input_scale = float(input_qparams.scale) + input_zero_point = int(input_qparams.zp) + + output_qparams = None + if node.meta.get("output_qparams"): + output_qparams = node.meta["output_qparams"].get(0) + + if output_qparams is not None: + if getattr(output_qparams, "per_channel", False): + return None + output_scale = float(output_qparams.scale) + output_zero_point = int(output_qparams.zp) + activation_min = int(output_qparams.qmin) + activation_max = int(output_qparams.qmax) + if abs(input_scale - output_scale) > 1e-6: + return None + if input_zero_point != output_zero_point: + return None + else: + output_zero_point = input_zero_point + activation_min = torch.iinfo(torch.int8).min + activation_max = torch.iinfo(torch.int8).max + + kernel_size = _to_int_pair(node.args[1], None) + stride_arg = node.args[2] if len(node.args) > 2 else None + stride = _to_int_pair(stride_arg, kernel_size) + padding_arg = node.args[3] if len(node.args) > 3 else None + padding = _to_int_pair(padding_arg, (0, 0)) + dilation_arg = node.args[4] if len(node.args) > 4 else None + dilation = _to_int_pair(dilation_arg, (1, 1)) + ceil_mode_arg = node.args[5] if len(node.args) > 5 else False + ceil_mode = _to_bool(ceil_mode_arg, False) + + if dilation != (1, 1) or ceil_mode: + return None + + quantized_op = getattr(exir_ops.edge.cortex_m, "quantized_max_pool2d", None) + if quantized_op is None: + return None + + args = ( + node.args[0], + kernel_size, + stride, + padding, + dilation, + ceil_mode, + input_zero_point, + output_zero_point, + activation_min, + activation_max, + ) + return DialectNodeSpec(quantized_op.default, args) + + +@AtenToCortexMPass.register_dialect_substitution(exir_ops.edge.aten.minimum.default) +def _get_minimum_replacement( + node: Node, dialect_pass: AtenToDialectPass +) -> DialectNodeSpec | None: + del dialect_pass + input_tensor = _get_input_tensor_data(node) + if input_tensor.dtype not in (torch.int8, torch.int32): + return None + return DialectNodeSpec(exir_ops.edge.cortex_m.minimum.default, node.args) + + +@AtenToCortexMPass.register_dialect_substitution(exir_ops.edge.aten.maximum.default) +def _get_maximum_replacement( + node: Node, dialect_pass: AtenToDialectPass +) -> DialectNodeSpec | None: + del dialect_pass + input_tensor = _get_input_tensor_data(node) + if input_tensor.dtype != torch.int8: + return None + return DialectNodeSpec(exir_ops.edge.cortex_m.maximum.default, node.args) + + +@AtenToCortexMPass.register_dialect_substitution( + exir_ops.edge.aten.permute_copy.default +) +def _get_permute_replacement( + node: Node, dialect_pass: AtenToDialectPass +) -> DialectNodeSpec | None: + del dialect_pass + input_tensor = _get_input_tensor_data(node) + if input_tensor.dtype != torch.int8: + return None + + rank = len(input_tensor.shape) + perms = [p % rank for p in cast(tuple[int, ...], node.args[1])] + return DialectNodeSpec( + exir_ops.edge.cortex_m.transpose.default, (node.args[0], perms) + ) + + +@AtenToCortexMPass.register_dialect_substitution( + exir_ops.edge.aten.constant_pad_nd.default +) +def _get_pad_replacement( + node: Node, dialect_pass: AtenToDialectPass +) -> DialectNodeSpec | None: + del dialect_pass + input_qparams = node.meta.get("input_qparams", {}) + if not input_qparams: + return None + + scale = float(input_qparams[0].scale) + zero_point = int(input_qparams[0].zp) + padding = cast(tuple[int, ...], node.args[1]) + pad_value_raw = node.args[2] if len(node.args) > 2 else 0 + pad_value_float = float(cast(float, pad_value_raw)) + + quantized_pad_value = int( + quantize_val(pad_value_float, scale, zero_point, -128, 127) + ) + + input_tensor = _get_input_tensor_data(node) + rank = len(input_tensor.shape) + assert 1 <= rank <= 4, f"cortex_m pad: expected rank in [1, 4], got {rank}" + n_pairs = len(padding) // 2 + assert ( + len(padding) % 2 == 0 and n_pairs <= rank + ), f"cortex_m pad: invalid padding length {len(padding)} for rank {rank}" + + pre_pad = [0, 0, 0, 0] + post_pad = [0, 0, 0, 0] + for i in range(n_pairs): + dim_4d = 3 - i + pre_pad[dim_4d] = int(padding[2 * i]) + post_pad[dim_4d] = int(padding[2 * i + 1]) + + pre_pad = to_physical_order(pre_pad, input_tensor) + post_pad = to_physical_order(post_pad, input_tensor) + + args = (node.args[0], pre_pad, post_pad, int(quantized_pad_value)) + return DialectNodeSpec(exir_ops.edge.cortex_m.pad.default, args) diff --git a/backends/cortex_m/passes/cortex_m_pass_manager.py b/backends/cortex_m/passes/cortex_m_pass_manager.py index abd086c0505..ede60fbcbee 100644 --- a/backends/cortex_m/passes/cortex_m_pass_manager.py +++ b/backends/cortex_m/passes/cortex_m_pass_manager.py @@ -28,7 +28,6 @@ from .decompose_hardswish_pass import DecomposeHardswishPass from .decompose_mean_pass import DecomposeMeanPass from .quantized_clamp_activation_pass import QuantizedClampActivationPass -from .quantized_op_fusion_pass import QuantizedOpFusionPass from .replace_quant_nodes_pass import ReplaceQuantNodesPass PassClass = Type[ExportPass] @@ -44,7 +43,6 @@ class CortexMPassManager(PassManager): ActivationFusionPass, QuantizedClampActivationPass, DecomposeHardswishPass, - QuantizedOpFusionPass, AtenToCortexMPass, ] diff --git a/backends/cortex_m/passes/quantized_op_fusion_pass.py b/backends/cortex_m/passes/quantized_op_fusion_pass.py deleted file mode 100644 index 5072a67f0ed..00000000000 --- a/backends/cortex_m/passes/quantized_op_fusion_pass.py +++ /dev/null @@ -1,368 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# Copyright 2025-2026 Arm Limited and/or its affiliates. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -import math -from typing import cast, Dict, Optional - -import torch -from executorch.backends.cortex_m.passes.passes_utils import ( - quantize_multiplier_aot, - quantize_val, - SHIFT_INT8, - to_physical_order, -) -from executorch.backends.cortex_m.quantizer.quantization_configs import ( - CMSIS_SOFTMAX_SCALE, - CMSIS_SOFTMAX_ZERO_POINT, -) -from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.dialects.edge._ops import EdgeOpOverload -from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue -from torch.fx.node import Argument - - -class QuantizedOpFusionPass(ExportPass): - """ - Generic ExportPass that: - 1. Replaces certain ops with cortex_m variants based on qualifiers. - 2. Fuses patterns: dequantize_per_tensor -> [binary_op] -> quantize_per_tensor - into cortex_m.quantized_[op].default with AoT computed multipliers/shifts. - - - Supports multiple binary operations with backward compatibility for add. - """ - - _SOFTMAX_INPUT_INTEGER_BITS = 5 - - def _get_add_replacement(self, args, meta): - if ( - meta.data.get("input_qparams", {}) == {} - or meta.data.get("output_qparams", {}) == {} - ): - return exir_ops.edge.aten.add.Tensor, args - - # Extract values - scale1 = meta["input_qparams"][0].scale - zero_point1 = meta["input_qparams"][0].zp - scale2 = meta["input_qparams"][1].scale - zero_point2 = meta["input_qparams"][1].zp - output_scale = meta["output_qparams"][0].scale - output_zero_point = meta["output_qparams"][0].zp - - # AoT COMPUTATION: Calculate multipliers and shifts - max_scale_2x = 2 * max(scale1, scale2) - - input1_mult, input1_shift = quantize_multiplier_aot(scale1 / max_scale_2x) - input2_mult, input2_shift = quantize_multiplier_aot(scale2 / max_scale_2x) - output_mult, output_shift = quantize_multiplier_aot( - max_scale_2x / (output_scale * (1 << SHIFT_INT8)) - ) - - activation_min = meta["output_qparams"][0].qmin - activation_max = meta["output_qparams"][0].qmax - - args = ( - args[0], - zero_point1, - input1_mult, - input1_shift, - args[1], - zero_point2, - input2_mult, - input2_shift, - output_zero_point, - output_mult, - output_shift, - activation_min, - activation_max, - ) - - return exir_ops.edge.cortex_m.quantized_add.default, args - - def _get_mul_replacement(self, args, meta): - if ( - meta.data.get("input_qparams", {}) == {} - or meta.data.get("output_qparams", {}) == {} - ): - return exir_ops.edge.aten.mul.Tensor, args - - # Extract values - scale1 = meta["input_qparams"][0].scale - zero_point1 = meta["input_qparams"][0].zp - scale2 = meta["input_qparams"][1].scale - zero_point2 = meta["input_qparams"][1].zp - output_scale = meta["output_qparams"][0].scale - output_zero_point = meta["output_qparams"][0].zp - - scale_factor = (scale1 * scale2) / output_scale - output_mult, output_shift = quantize_multiplier_aot(scale_factor) - - args = ( - args[0], - zero_point1, - args[1], - zero_point2, - output_zero_point, - output_mult, - output_shift, - ) - - return exir_ops.edge.cortex_m.quantized_mul.default, args - - def _compute_softmax_params(self, input_scale: float) -> tuple[int, int, int]: - """ - Convert the incoming per-tensor input scale into the CMSIS fixed-point - parameters expected by `arm_softmax_s8`. - - 1. Clamp the real multiplier to the Q31 range using the fixed number of - input integer bits mandated by CMSIS. - 2. Feed that multiplier through `quantize_multiplier_aot` to get the - (multiplier, shift) pair arm_softmax_s8 expects. - 3. Derive `diff_min`, the CMSIS threshold for early bailout when - differences saturate, using the same multiplier/shift values. - """ - real_multiplier = min( - input_scale * (1 << (31 - self._SOFTMAX_INPUT_INTEGER_BITS)), - float((1 << 31) - 1), - ) - input_multiplier, input_shift = quantize_multiplier_aot(real_multiplier) - diff_min_term = ( - ((1 << self._SOFTMAX_INPUT_INTEGER_BITS) - 1) - * math.ldexp(1.0, 31 - self._SOFTMAX_INPUT_INTEGER_BITS) - / math.ldexp(1.0, input_shift) - ) - diff_min = -int(math.floor(diff_min_term)) - return int(input_multiplier), int(input_shift), diff_min - - def _get_softmax_replacement(self, args, meta): - if ( - meta.data.get("input_qparams", {}) == {} - or meta.data.get("output_qparams", {}) == {} - ): - return exir_ops.edge.aten._softmax.default, args - - input_qparams = meta["input_qparams"][0] - output_qparams = meta["output_qparams"][0] - - half_to_float = args[2] if len(args) > 2 else False - if half_to_float: - return exir_ops.edge.aten._softmax.default, args - - input_multiplier, input_shift, diff_min = self._compute_softmax_params( - float(input_qparams.scale) - ) - - output_scale_attr = getattr(output_qparams, "scale", None) - output_zp_attr = getattr(output_qparams, "zp", None) - if output_scale_attr is None or output_zp_attr is None: - raise AssertionError("Softmax requires output quantization parameters.") - - output_scale_val = float(output_scale_attr) - output_zp_val = int(output_zp_attr) - if not math.isclose( - output_scale_val, CMSIS_SOFTMAX_SCALE, rel_tol=0.0, abs_tol=1e-12 - ): - raise AssertionError( - "Softmax output scale must match CMSIS (1/256). " - f"Got {output_scale_val}." - ) - if output_zp_val != CMSIS_SOFTMAX_ZERO_POINT: - raise AssertionError( - "Softmax output zero-point must match CMSIS (-128). " - f"Got {output_zp_val}." - ) - - new_args = ( - args[0], - args[1], - int(input_qparams.zp), - output_zp_val, - input_multiplier, - input_shift, - diff_min, - ) - - return exir_ops.edge.cortex_m.softmax.default, new_args - - def _to_int_pair( - self, value: Argument, default: Optional[tuple[int, int]] - ) -> tuple[int, int]: - if value is None: - assert default is not None, "Expected default sequence for normalization" - return (default[0], default[1]) - - try: - int_pair = cast(tuple[int, int], value) - return int_pair - except Exception: - raise ValueError(f"Expected a tuple of two integers, got {value}") - - def _unwrap_argument(self, arg: Argument) -> Argument: - if isinstance(arg, ProxyValue): - return arg.data - return arg - - def _to_bool(self, value: Argument, default: bool) -> bool: - if value is None: - return default - try: - bool_value = cast(bool, value) - return bool_value - except Exception: - raise ValueError(f"Expected a boolean value, got {value}") - - def _get_max_pool2d_replacement(self, args, meta): - input_qparams = meta["input_qparams"].get(0) - cortex_m_meta = meta.data.get("custom", {}).get("cortex_m", {}) - if input_qparams is None or cortex_m_meta.get( - "skip_quantized_max_pool2d", False - ): - return exir_ops.edge.aten.max_pool2d.default, args - - input_scale = float(input_qparams.scale) - input_zero_point = int(input_qparams.zp) - - output_qparams = None - if meta.data.get("output_qparams"): - output_qparams = meta["output_qparams"].get(0) - - if output_qparams is not None: - if getattr(output_qparams, "per_channel", False): - return exir_ops.edge.aten.max_pool2d.default, args - output_scale = float(output_qparams.scale) - output_zero_point = int(output_qparams.zp) - activation_min = int(output_qparams.qmin) - activation_max = int(output_qparams.qmax) - if abs(input_scale - output_scale) > 1e-6: - return exir_ops.edge.aten.max_pool2d.default, args - if input_zero_point != output_zero_point: - return exir_ops.edge.aten.max_pool2d.default, args - else: - output_zero_point = input_zero_point - activation_min = torch.iinfo(torch.int8).min - activation_max = torch.iinfo(torch.int8).max - - kernel_size = self._to_int_pair(args[1], None) - stride_arg = args[2] if len(args) > 2 else None - stride = self._to_int_pair(stride_arg, kernel_size) - padding_arg = args[3] if len(args) > 3 else None - padding = self._to_int_pair(padding_arg, (0, 0)) - dilation_arg = args[4] if len(args) > 4 else None - dilation = self._to_int_pair(dilation_arg, (1, 1)) - - ceil_mode_arg = args[5] if len(args) > 5 else False - ceil_mode = self._to_bool(ceil_mode_arg, False) - - if dilation != (1, 1) or ceil_mode: - return exir_ops.edge.aten.max_pool2d.default, args - - quantized_op = getattr(exir_ops.edge.cortex_m, "quantized_max_pool2d", None) - if quantized_op is None: - return exir_ops.edge.aten.max_pool2d.default, args - - new_args = ( - args[0], - kernel_size, - stride, - padding, - dilation, - ceil_mode, - input_zero_point, - output_zero_point, - activation_min, - activation_max, - ) - - return quantized_op.default, new_args - - def _get_minimum_replacement(self, args, meta): - if args[0].data.dtype not in (torch.int8, torch.int32): - return exir_ops.edge.aten.minimum.default, args - - return exir_ops.edge.cortex_m.minimum.default, args - - def _get_maximum_replacement(self, args, meta): - if args[0].data.dtype != torch.int8: - return exir_ops.edge.aten.maximum.default, args - - return exir_ops.edge.cortex_m.maximum.default, args - - def _get_permute_replacement(self, args, meta): - if args[0].data.dtype != torch.int8: - return exir_ops.edge.aten.permute_copy.default, args - - rank = len(args[0].data.shape) - perms = [p % rank for p in args[1]] - args = (args[0], perms) - return exir_ops.edge.cortex_m.transpose.default, args - - def _get_pad_replacement(self, args, meta): - input_qparams = meta.data.get("input_qparams", {}) - if not input_qparams: - return exir_ops.edge.aten.constant_pad_nd.default, args - - scale = float(input_qparams[0].scale) - zero_point = int(input_qparams[0].zp) - - padding = self._unwrap_argument(args[1]) - pad_value_raw = self._unwrap_argument(args[2]) if len(args) > 2 else 0 - pad_value_float = float(pad_value_raw) - - quantized_pad_value = int( - quantize_val(pad_value_float, scale, zero_point, -128, 127) - ) - - rank = len(args[0].data.shape) - assert 1 <= rank <= 4, f"cortex_m pad: expected rank in [1, 4], got {rank}" - n_pairs = len(padding) // 2 - assert ( - len(padding) % 2 == 0 and n_pairs <= rank - ), f"cortex_m pad: invalid padding length {len(padding)} for rank {rank}" - - pre_pad = [0, 0, 0, 0] - post_pad = [0, 0, 0, 0] - for i in range(n_pairs): - dim_4d = 3 - i - pre_pad[dim_4d] = int(padding[2 * i]) - post_pad[dim_4d] = int(padding[2 * i + 1]) - - pre_pad = to_physical_order(pre_pad, args[0].data) - post_pad = to_physical_order(post_pad, args[0].data) - - new_args = (args[0], pre_pad, post_pad, int(quantized_pad_value)) - return exir_ops.edge.cortex_m.pad.default, new_args - - def call_operator( - self, - op: EdgeOpOverload, - args: tuple[Argument, ...], - kwargs: Dict[str, Argument], - meta: NodeMetadata, - ) -> ProxyValue: - - match op: - case exir_ops.edge.aten.add.Tensor: - op, args = self._get_add_replacement(args, meta) - case exir_ops.edge.aten.mul.Tensor: - op, args = self._get_mul_replacement(args, meta) - case exir_ops.edge.aten._softmax.default: - op, args = self._get_softmax_replacement(args, meta) - case exir_ops.edge.aten.max_pool2d.default: - op, args = self._get_max_pool2d_replacement(args, meta) - case exir_ops.edge.aten.minimum.default: - op, args = self._get_minimum_replacement(args, meta) - case exir_ops.edge.aten.maximum.default: - op, args = self._get_maximum_replacement(args, meta) - case exir_ops.edge.aten.permute_copy.default: - op, args = self._get_permute_replacement(args, meta) - case exir_ops.edge.aten.constant_pad_nd.default: - op, args = self._get_pad_replacement(args, meta) - case _: - pass - - result = super().call_operator(op, args, {}, meta) - return result