diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index 485e01278d9..3251cde3a07 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -34,6 +34,7 @@ ConvertToClampPass, DecomposeAcoshPass, DecomposeAdaptiveAvgPool2dPass, + DecomposeAdaptiveMaxPool2dPass, DecomposeAddmmPass, DecomposeAddSubAlphaPass, DecomposeAnyPass, @@ -527,6 +528,7 @@ def _tosa_pipeline( [ RewriteUpsamplePass(), RewriteMaxPool2dPass(), + DecomposeAdaptiveMaxPool2dPass(), RewriteConvPass(exported_program), RewriteMatmulPass(), RewritePadPass(), diff --git a/backends/arm/operators/__init__.py b/backends/arm/operators/__init__.py index 32809eed847..9436bfe2ab3 100644 --- a/backends/arm/operators/__init__.py +++ b/backends/arm/operators/__init__.py @@ -47,6 +47,7 @@ op_tanh, op_to_dim_order_copy, op_tosa_avg_pool2d, + op_tosa_avg_pool2d_adaptive, op_tosa_conv2d, op_tosa_conv3d, op_tosa_custom, @@ -55,6 +56,7 @@ op_tosa_identity, op_tosa_matmul, op_tosa_max_pool2d, + op_tosa_max_pool2d_adaptive, op_tosa_pad, op_tosa_rescale, op_tosa_resize, diff --git a/backends/arm/operators/op_tosa_avg_pool2d_adaptive.py b/backends/arm/operators/op_tosa_avg_pool2d_adaptive.py new file mode 100644 index 00000000000..d8f20653fe7 --- /dev/null +++ b/backends/arm/operators/op_tosa_avg_pool2d_adaptive.py @@ -0,0 +1,71 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, List + +import torch +import tosa_serializer as ts + +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, + validate_same_dtype, + validate_valid_dtype, +) +from executorch.backends.arm.tosa.mapping import TosaArg + + +if hasattr(ts.Op, "AVG_POOL2D_ADAPTIVE"): + + @register_node_visitor + class AvgPool2dAdaptiveVisitor(NodeVisitor): + """Visitor for lowering TOSA AVG_POOL2D_ADAPTIVE operator.""" + + target = "tosa.AVG_POOL2D_ADAPTIVE.default" + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + validate_num_inputs(self.target, inputs, [7]) + validate_same_dtype(self.target, [inputs[0], output], ts) + + input_tensor, input_zp, output_zp, kernel, stride, pad, acc_arg = inputs + + supported = [ts.DType.INT8, ts.DType.FP16, ts.DType.FP32, ts.DType.BF16] + if self.tosa_spec.support_extension("int16"): + supported.append(ts.DType.INT16) + if self.tosa_spec.support_extension("fp8e4m3"): + supported.append(ts.DType.FP8E4M3) + if self.tosa_spec.support_extension("fp8e5m2"): + supported.append(ts.DType.FP8E5M2) + validate_valid_dtype( + self.target, [input_tensor, output], supported, self.tosa_spec + ) + + attr = ts.TosaSerializerAttribute() + attr.AvgPool2dAdaptiveAttribute(acc_type=acc_arg.dtype) + + self._serialize_operator( + node, + tosa_graph, + ts.Op.AVG_POOL2D_ADAPTIVE, + [ + input_tensor.name, + input_zp.name, + output_zp.name, + kernel.name, + stride.name, + pad.name, + ], + [output.name], + attr, + ) diff --git a/backends/arm/operators/op_tosa_max_pool2d_adaptive.py b/backends/arm/operators/op_tosa_max_pool2d_adaptive.py new file mode 100644 index 00000000000..07f3dbe89cb --- /dev/null +++ b/backends/arm/operators/op_tosa_max_pool2d_adaptive.py @@ -0,0 +1,60 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, List + +import torch +import tosa_serializer as ts + +from executorch.backends.arm.operators.node_visitor import ( + NodeVisitor, + register_node_visitor, +) +from executorch.backends.arm.operators.operator_validation_utils import ( + validate_num_inputs, + validate_same_dtype, + validate_valid_dtype, +) +from executorch.backends.arm.tosa.mapping import TosaArg + + +if hasattr(ts.Op, "MAX_POOL2D_ADAPTIVE"): + + @register_node_visitor + class MaxPool2dAdaptiveVisitor(NodeVisitor): + """Visitor for lowering TOSA MAX_POOL2D_ADAPTIVE operator.""" + + target = "tosa.MAX_POOL2D_ADAPTIVE.default" + + def define_node( + self, + node: torch.fx.Node, + tosa_graph: Any, + inputs: List[TosaArg], + output: TosaArg, + ) -> None: + validate_num_inputs(self.target, inputs, [4]) + validate_same_dtype(self.target, [inputs[0], output], ts) + + input_tensor, kernel, stride, pad = inputs + + supported = [ts.DType.INT8, ts.DType.FP16, ts.DType.FP32, ts.DType.BF16] + if self.tosa_spec.support_extension("int16"): + supported.append(ts.DType.INT16) + validate_valid_dtype( + self.target, [input_tensor, output], supported, self.tosa_spec + ) + + attr = ts.TosaSerializerAttribute() + attr.MaxPool2dAdaptiveAttribute(nan_mode=ts.NanPropagationMode.PROPAGATE) + + self._serialize_operator( + node, + tosa_graph, + ts.Op.MAX_POOL2D_ADAPTIVE, + [input_tensor.name, kernel.name, stride.name, pad.name], + [output.name], + attr, + )