diff --git a/backends/arm/_passes/__init__.py b/backends/arm/_passes/__init__.py index 174a1960aab..6b39d9a25c3 100644 --- a/backends/arm/_passes/__init__.py +++ b/backends/arm/_passes/__init__.py @@ -140,6 +140,9 @@ from .remove_getitem_pass import RemoveGetItemPass # noqa from .remove_graph_asserts_pass import RemoveGraphAssertsPass # noqa from .remove_noop_pass import RemoveNoopPass # noqa +from .remove_permutes_around_elementwise_tosa_ops import ( # noqa + RemovePermutesAroundElementwiseTosaOps, +) from .replace_scalar_with_tensor_pass import ( # noqa ReplaceScalarWithTensorByProfilePass, ) diff --git a/backends/arm/_passes/arm_pass_manager.py b/backends/arm/_passes/arm_pass_manager.py index ebe6c4591e6..1954d958bb3 100644 --- a/backends/arm/_passes/arm_pass_manager.py +++ b/backends/arm/_passes/arm_pass_manager.py @@ -125,6 +125,7 @@ RemoveGetItemPass, RemoveGraphAssertsPass, RemoveNoopPass, + RemovePermutesAroundElementwiseTosaOps, ReplaceInfAndLimitValuesPass, ReplaceScalarWithTensorByProfilePass, RewriteAvgPool2dPass, @@ -164,9 +165,6 @@ PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView, ) -from executorch.backends.transforms.remove_permutes_around_elementwise_ops import ( - RemovePermutesAroundElementwiseOps, -) from executorch.exir import ExportedProgram from executorch.exir.pass_base import ExportPass from executorch.exir.pass_manager import PassManager @@ -538,7 +536,7 @@ def _tosa_pipeline( RewriteMatmulPass(), RewritePadPass(), FuseViewCopyTransformPass(), - RemovePermutesAroundElementwiseOps(), + RemovePermutesAroundElementwiseTosaOps(), PostponePermuteOpBelowSqueezeOrUnsqueezeLikeView(), FuseCascadedTransposeOrPermuteOps(), ConvertPermuteSingletonToViewPass(), diff --git a/backends/arm/_passes/remove_permutes_around_elementwise_tosa_ops.py b/backends/arm/_passes/remove_permutes_around_elementwise_tosa_ops.py new file mode 100644 index 00000000000..bc03ebacd81 --- /dev/null +++ b/backends/arm/_passes/remove_permutes_around_elementwise_tosa_ops.py @@ -0,0 +1,17 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from executorch.backends.transforms.remove_permutes_around_elementwise_ops import ( + RemovePermutesAroundElementwiseOps, +) +from executorch.exir.dialects._ops import ops as exir_ops + + +class RemovePermutesAroundElementwiseTosaOps(RemovePermutesAroundElementwiseOps): + permutable_ops = { + *RemovePermutesAroundElementwiseOps.permutable_ops, + exir_ops.backend.tosa.RESCALE.default, + exir_ops.backend.tosa.TABLE.default, + } diff --git a/backends/arm/test/passes/test_remove_permutes_around_elementwise_tosa_ops.py b/backends/arm/test/passes/test_remove_permutes_around_elementwise_tosa_ops.py new file mode 100644 index 00000000000..341d985134e --- /dev/null +++ b/backends/arm/test/passes/test_remove_permutes_around_elementwise_tosa_ops.py @@ -0,0 +1,59 @@ +# Copyright 2026 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.backends.arm._passes.remove_permutes_around_elementwise_tosa_ops import ( + RemovePermutesAroundElementwiseTosaOps, +) +from executorch.backends.arm.tosa.specification import ( + TosaLoweringContext, + TosaSpecification, +) +from executorch.exir.dialects._ops import ops as exir_ops + +TOSA_INT_SPEC = TosaSpecification.create_from_string("TOSA-1.0+INT") +PERMUTE_TARGET = exir_ops.edge.aten.permute_copy.default +RESCALE_TARGET = exir_ops.backend.tosa.RESCALE.default +TABLE_TARGET = exir_ops.backend.tosa.TABLE.default + + +def _count_nodes(graph_module: torch.fx.GraphModule, target) -> int: + return sum( + 1 + for node in graph_module.graph.nodes + if node.op == "call_function" and node.target == target + ) + + +def test_remove_permutes_around_rescale_tosa_INT() -> None: + graph = torch.fx.Graph() + x = graph.placeholder("x") + x.meta["val"] = torch.randn(1, 3, 4, 5) + + permute_in = graph.create_node( + "call_function", + PERMUTE_TARGET, + args=(x, [0, 2, 3, 1]), + ) + rescale = graph.create_node( + "call_function", + RESCALE_TARGET, + args=(permute_in, torch.int8, [1.0], 0, 0), + ) + permute_out = graph.create_node( + "call_function", + PERMUTE_TARGET, + args=(rescale, [0, 3, 1, 2]), + ) + graph.output(permute_out) + + graph_module = torch.fx.GraphModule({}, graph) + + with TosaLoweringContext(TOSA_INT_SPEC): + result = RemovePermutesAroundElementwiseTosaOps().call(graph_module) + + assert result.modified + assert _count_nodes(result.graph_module, PERMUTE_TARGET) == 0 + assert _count_nodes(result.graph_module, RESCALE_TARGET) == 1