From 21f192378e50edc6ecdadd91f3e8ec2eac48671f Mon Sep 17 00:00:00 2001
From: Adrian Lundell
Date: Tue, 17 Mar 2026 14:28:50 +0100
Subject: [PATCH] Arm backend: Generalize fuse_view_copy_transform_pass

Update the fuse_view_copy_transform_pass to check shapes rather than
args to match different ways of expressing the same shape. This change
makes some as_strided ops into noops which they were not previously, so
an additional noop check is added to handle this. Additionally,
constants in the pass are moved to class level to simplify overriding
behaviour in inheriting passes.

Note: test names as_strided_copy -> as_strided, since the _copy suffix
of operator names is removed in the name check.

Signed-off-by: Adrian Lundell
Change-Id: Iba8dc3862c3fc4a8a34e036377ec7bcee84988b1
---
 backends/arm/test/ops/test_as_strided_copy.py |  13 +-
 backends/arm/tosa/partitioner.py              |  14 ++
 backends/transforms/fuse_view_copy.py         | 178 +++++++++---------
 3 files changed, 111 insertions(+), 94 deletions(-)

diff --git a/backends/arm/test/ops/test_as_strided_copy.py b/backends/arm/test/ops/test_as_strided_copy.py
index 9ed0b52da49..55c72a7fc54 100644
--- a/backends/arm/test/ops/test_as_strided_copy.py
+++ b/backends/arm/test/ops/test_as_strided_copy.py
@@ -52,7 +52,7 @@ def _make_case(
 delegated_cases = {
     "reshape_2d": lambda: _make_case((4, 6), (3, 8)),
     "flatten": lambda: _make_case((2, 3, 4), (6, 4)),
-    "expand_rank": lambda: _make_case((2, 3, 4), (2, 3, 4)),
+    "expand_rank": lambda: _make_case((2, 3, 4), (1, 2, 3, 4)),
 }
 
 unsupported_cases = {
@@ -67,11 +67,12 @@ def _make_case(
         contiguous_strides((4, 4)),
         4,
     ),
+    "noop": lambda: _make_case((2, 3, 4), (2, 3, 4)),  # Single noop is not delegated
 }
 
 
 @common.parametrize("test_data", delegated_cases)
-def test_as_strided_copy_tosa_FP(test_data):
+def test_as_strided_tosa_FP(test_data):
     tensor, size, stride = test_data()
     module = AsStridedCopyModule(size, stride)
     pipeline = TosaPipelineFP[input_t](
@@ -83,7 +84,7 @@
 
 
 @common.parametrize("test_data", delegated_cases)
-def test_as_strided_copy_tosa_INT(test_data):
+def test_as_strided_tosa_INT(test_data):
     tensor, size, stride = test_data()
     module = AsStridedCopyModule(size, stride)
     pipeline = TosaPipelineINT[input_t](
@@ -96,7 +97,7 @@
 
 @common.parametrize("test_data", delegated_cases)
 @common.SkipIfNoModelConverter
-def test_as_strided_copy_vgf_no_quant(test_data):
+def test_as_strided_vgf_no_quant(test_data):
     tensor, size, stride = test_data()
     module = AsStridedCopyModule(size, stride)
     pipeline = VgfPipeline[input_t](
@@ -111,7 +112,7 @@
 
 @common.parametrize("test_data", delegated_cases)
 @common.SkipIfNoModelConverter
-def test_as_strided_copy_vgf_quant(test_data):
+def test_as_strided_vgf_quant(test_data):
     tensor, size, stride = test_data()
     module = AsStridedCopyModule(size, stride)
     pipeline = VgfPipeline[input_t](
@@ -124,7 +125,7 @@
 
 
 @common.parametrize("test_data", unsupported_cases)
-def test_as_strided_copy_not_delegated(test_data):
+def test_as_strided_no_target_not_delegated(test_data):
     tensor, size, stride, *rest = test_data()
     storage_offset = rest[0] if rest else 0
     module = AsStridedCopyModule(size, stride, storage_offset=storage_offset)
diff --git a/backends/arm/tosa/partitioner.py b/backends/arm/tosa/partitioner.py
index dba3c5287d3..29abefbf2ac 100644
--- a/backends/arm/tosa/partitioner.py
+++ b/backends/arm/tosa/partitioner.py
@@ -60,6 +60,19 @@ def _is_noop_detach_copy(node: torch.fx.Node) -> bool:
     return node.target == exir_ops.edge.aten.detach_copy.default
 
 
+def _is_noop_as_strided_copy(node: torch.fx.Node) -> bool:
+    if node.target != exir_ops.edge.aten.as_strided_copy.default:
+        return False
+    else:
+        input_tensor = get_first_fake_tensor(ensure_type(torch.fx.Node, node.args[0]))
+        output_tensor = get_first_fake_tensor(node)
+        return (
+            input_tensor.shape == output_tensor.shape
+            and input_tensor.stride() == output_tensor.stride()
+            and input_tensor.storage_offset() == output_tensor.storage_offset()
+        )
+
+
 def _is_noop_to_dim_order_copy(node: torch.fx.node.Node) -> bool:
     if node.target != exir_ops.edge.dim_order_ops._to_dim_order_copy.default:
         return False
@@ -263,6 +276,7 @@ def _tag_module(  # noqa
                 or _is_noop_detach_copy(node)
                 or _is_noop_to_dim_order_copy(node)
                 or _is_view_copy(node)
+                or _is_noop_as_strided_copy(node)
                 or node.target in Q_OPS
                 or node.target in DQ_OPS
                 for node in partition.nodes
diff --git a/backends/transforms/fuse_view_copy.py b/backends/transforms/fuse_view_copy.py
index b7c52f95fa3..4453015abf3 100644
--- a/backends/transforms/fuse_view_copy.py
+++ b/backends/transforms/fuse_view_copy.py
@@ -1,6 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
-# Copyright 2025 Arm Limited and/or its affiliates.
+# Copyright 2025-2026 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -14,101 +14,103 @@
 from executorch.exir.pass_base import ExportPass, PassResult
 
 
-UNARY_ELEMENTWISE_OPS = [
-    exir_ops.edge.aten.view_copy.default,
-    exir_ops.edge.aten.alias_copy.default,
-    exir_ops.edge.aten.clone.default,
-    exir_ops.edge.dim_order_ops._clone_dim_order.default,
-    exir_ops.edge.aten._to_copy.default,
-    exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
-    exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
-    exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
-    exir_ops.edge.aten.abs.default,
-    exir_ops.edge.aten.clamp.default,
-    exir_ops.edge.aten.ceil.default,
-    exir_ops.edge.aten.floor.default,
-    exir_ops.edge.aten.neg.default,
-    exir_ops.edge.aten.relu.default,
-    exir_ops.edge.aten.round.default,
-    exir_ops.edge.aten.sigmoid.default,
-    exir_ops.edge.aten.silu.default,
-    exir_ops.edge.aten.sqrt.default,
-    exir_ops.edge.aten.tanh.default,
-    exir_ops.edge.aten.sign.default,
-    exir_ops.edge.aten.reciprocal.default,
-    exir_ops.edge.aten.gelu.default,
-    exir_ops.edge.aten.rsqrt.default,
-    exir_ops.edge.aten.exp.default,
-    exir_ops.edge.aten.log.default,
-]
-
-
-def merge_view_copy_chains(graph: torch.fx.Graph) -> tuple[torch.fx.Graph, bool]:
-    """
-    Find chains of view_copy nodes and unary elementwise ops and set all
-    view_copy nodes to have the final shape. The views will then be removed
-    by the remove_noop_view_copy call.
-
-    Only merges view_copy nodes that are not used by any other nodes.
-    """
-    ops = exir_ops.edge
-    view_op = ops.aten.view_copy.default
-    modified = False
-    for node in graph.nodes:
-        if node.op == "call_function" and node.target == view_op:
-            # Find a chain of unary elementwise ops and save all view_copy nodes
-            end_node = node
-            view_ops = [node]
-            while (
-                end_node.op == "call_function"
-                and end_node.target in UNARY_ELEMENTWISE_OPS
-                and len(end_node.users) == 1
-                and list(end_node.users)[0].target in UNARY_ELEMENTWISE_OPS
-            ):
-                end_node = list(end_node.users)[0]
-                if end_node.target == view_op:
-                    view_ops.append(end_node)
-
-            # Set all view_copy nodes to have the final shape
-            if len(view_ops) > 1:
-                final_shape = view_ops[-1].args[1]
-                for node in view_ops:
-                    new_args = (node.args[0], final_shape)
-                    node.args = new_args
-                modified = True
-
-    graph.eliminate_dead_code()
-    return graph, modified
-
-
-def remove_noop_view_copy(graph: torch.fx.Graph) -> tuple[torch.fx.Graph, bool]:
-    """
-    Remove view_copy nodes that are no-ops.
-    """
-    ops = exir_ops.edge
-    view_op = ops.aten.view_copy.default
-    modified = False
-    for node in graph.nodes:
-        if node.op == "call_function" and node.target == view_op:
-            input_shape = list(node.args[0].meta["val"].shape)
-            target_shape = node.args[1]
-            if input_shape == target_shape:
-                node.replace_all_uses_with(node.args[0])
-                modified = True
-    graph.eliminate_dead_code()
-    return graph, modified
-
-
 class FuseViewCopyTransform(ExportPass):
     _passes_required_after: Set[Type[ExportPass]] = set()
 
+    VIEW_OP = exir_ops.edge.aten.view_copy.default
+
+    UNARY_ELEMENTWISE_OPS = [
+        exir_ops.edge.aten.alias_copy.default,
+        exir_ops.edge.aten.clone.default,
+        exir_ops.edge.dim_order_ops._clone_dim_order.default,
+        exir_ops.edge.aten._to_copy.default,
+        exir_ops.edge.dim_order_ops._to_dim_order_copy.default,
+        exir_ops.edge.quantized_decomposed.quantize_per_tensor.default,
+        exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
+        exir_ops.edge.aten.abs.default,
+        exir_ops.edge.aten.clamp.default,
+        exir_ops.edge.aten.ceil.default,
+        exir_ops.edge.aten.floor.default,
+        exir_ops.edge.aten.neg.default,
+        exir_ops.edge.aten.relu.default,
+        exir_ops.edge.aten.round.default,
+        exir_ops.edge.aten.sigmoid.default,
+        exir_ops.edge.aten.silu.default,
+        exir_ops.edge.aten.sqrt.default,
+        exir_ops.edge.aten.tanh.default,
+        exir_ops.edge.aten.sign.default,
+        exir_ops.edge.aten.reciprocal.default,
+        exir_ops.edge.aten.gelu.default,
+        exir_ops.edge.aten.rsqrt.default,
+        exir_ops.edge.aten.exp.default,
+        exir_ops.edge.aten.log.default,
+    ]
+
+    def merge_view_copy_chains(
+        self, graph: torch.fx.Graph
+    ) -> tuple[torch.fx.Graph, bool]:
+        """
+        Find chains of view_copy nodes and unary elementwise ops and set all
+        view_copy nodes to have the final shape. The views will then be removed
+        by the remove_noop_view_copy call.
+
+        Only merges view_copy nodes that are not used by any other nodes.
+        """
+        view_op = self.VIEW_OP
+        modified = False
+        ops = self.UNARY_ELEMENTWISE_OPS + [view_op]
+        for node in graph.nodes:
+            if node.op == "call_function" and node.target == view_op:
+                # Find a chain of unary elementwise ops and save all view_copy nodes
+                end_node = node
+                view_ops = [node]
+                while (
+                    end_node.op == "call_function"
+                    and end_node.target in ops
+                    and len(end_node.users) == 1
+                    and list(end_node.users)[0].target in ops
+                ):
+                    end_node = list(end_node.users)[0]
+                    if end_node.target == view_op:
+                        view_ops.append(end_node)
+
+                # Set all view_copy nodes to have the final shape
+                if len(view_ops) > 1:
+                    final_shape = view_ops[-1].args[1]
+                    for node in view_ops:
+                        new_args = (node.args[0], final_shape)
+                        node.args = new_args
+                    modified = True
+        if modified:
+            graph.eliminate_dead_code()
+        return graph, modified
+
+    def remove_noop_view_copy(
+        self, graph: torch.fx.Graph
+    ) -> tuple[torch.fx.Graph, bool]:
+        """
+        Remove view_copy nodes that are no-ops.
+        """
+        view_op = self.VIEW_OP
+        modified = False
+        for node in graph.nodes:
+            if node.op == "call_function" and node.target == view_op:
+                input_shape = list(node.args[0].meta["val"].shape)
+                target_shape = list(node.meta["val"].shape)
+                if input_shape == target_shape:
+                    node.replace_all_uses_with(node.args[0])
+                    modified = True
+        if modified:
+            graph.eliminate_dead_code()
+        return graph, modified
+
     def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
-        graph_module.graph, modified = merge_view_copy_chains(graph_module.graph)
+        graph_module.graph, modified = self.merge_view_copy_chains(graph_module.graph)
         if modified:
             graph_module.recompile()
             graph_module = super().call(graph_module).graph_module
 
-        graph_module.graph, modified = remove_noop_view_copy(graph_module.graph)
+        graph_module.graph, modified = self.remove_noop_view_copy(graph_module.graph)
         if modified:
             graph_module.recompile()
             graph_module = super().call(graph_module).graph_module