diff --git a/backends/arm/test/ops/test_as_strided_copy.py b/backends/arm/test/ops/test_as_strided_copy.py index 9ed0b52da49..55c72a7fc54 100644 --- a/backends/arm/test/ops/test_as_strided_copy.py +++ b/backends/arm/test/ops/test_as_strided_copy.py @@ -52,7 +52,7 @@ def _make_case( delegated_cases = { "reshape_2d": lambda: _make_case((4, 6), (3, 8)), "flatten": lambda: _make_case((2, 3, 4), (6, 4)), - "expand_rank": lambda: _make_case((2, 3, 4), (2, 3, 4)), + "expand_rank": lambda: _make_case((2, 3, 4), (1, 2, 3, 4)), } unsupported_cases = { @@ -67,11 +67,12 @@ def _make_case( contiguous_strides((4, 4)), 4, ), + "noop": lambda: _make_case((2, 3, 4), (2, 3, 4)), # Single noop is not delegated } @common.parametrize("test_data", delegated_cases) -def test_as_strided_copy_tosa_FP(test_data): +def test_as_strided_tosa_FP(test_data): tensor, size, stride = test_data() module = AsStridedCopyModule(size, stride) pipeline = TosaPipelineFP[input_t]( @@ -83,7 +84,7 @@ def test_as_strided_copy_tosa_FP(test_data): @common.parametrize("test_data", delegated_cases) -def test_as_strided_copy_tosa_INT(test_data): +def test_as_strided_tosa_INT(test_data): tensor, size, stride = test_data() module = AsStridedCopyModule(size, stride) pipeline = TosaPipelineINT[input_t]( @@ -96,7 +97,7 @@ def test_as_strided_copy_tosa_INT(test_data): @common.parametrize("test_data", delegated_cases) @common.SkipIfNoModelConverter -def test_as_strided_copy_vgf_no_quant(test_data): +def test_as_strided_vgf_no_quant(test_data): tensor, size, stride = test_data() module = AsStridedCopyModule(size, stride) pipeline = VgfPipeline[input_t]( @@ -111,7 +112,7 @@ def test_as_strided_copy_vgf_no_quant(test_data): @common.parametrize("test_data", delegated_cases) @common.SkipIfNoModelConverter -def test_as_strided_copy_vgf_quant(test_data): +def test_as_strided_vgf_quant(test_data): tensor, size, stride = test_data() module = AsStridedCopyModule(size, stride) pipeline = VgfPipeline[input_t]( @@ -124,7 +125,7 @@ 
def _is_noop_as_strided_copy(node: torch.fx.Node) -> bool:
    """Return True if ``node`` is an ``as_strided_copy`` that leaves its input unchanged.

    A node qualifies only when all of the following hold:

    * its target is ``exir_ops.edge.aten.as_strided_copy.default``;
    * the fake tensors of its input and output agree on shape, stride and
      storage offset;
    * the stride and storage offset *requested* in the node's arguments are
      the ones the input already has.

    Args:
        node: FX node to inspect.

    Returns:
        ``True`` when the node can be treated as a no-op, ``False`` otherwise
        (including for every node that is not an ``as_strided_copy``).
    """
    if node.target != exir_ops.edge.aten.as_strided_copy.default:
        return False

    input_tensor = get_first_fake_tensor(ensure_type(torch.fx.Node, node.args[0]))
    output_tensor = get_first_fake_tensor(node)
    if (
        input_tensor.shape != output_tensor.shape
        or input_tensor.stride() != output_tensor.stride()
        or input_tensor.storage_offset() != output_tensor.storage_offset()
    ):
        return False

    # NOTE(review): the output of a ``*_copy`` op is presumably a freshly
    # materialised tensor, so its stride/offset metadata need not reflect the
    # stride/offset that was requested (e.g. a same-shape call with permuting
    # strides) -- TODO confirm. Checking the requested arguments as well keeps
    # this predicate conservative: it can only narrow the original result.
    requested_stride = (
        node.args[2] if len(node.args) > 2 else node.kwargs.get("stride")
    )
    requested_offset = (
        node.args[3] if len(node.args) > 3 else node.kwargs.get("storage_offset")
    )

    if requested_stride is None:
        # The requested stride cannot be verified, so do not claim a no-op.
        return False
    if tuple(requested_stride) != tuple(input_tensor.stride()):
        return False

    # A missing/None storage offset means "keep the input's offset" for
    # as_strided (see the torch.as_strided documentation).
    return (
        requested_offset is None
        or requested_offset == input_tensor.storage_offset()
    )
# # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -14,101 +14,103 @@ from executorch.exir.pass_base import ExportPass, PassResult -UNARY_ELEMENTWISE_OPS = [ - exir_ops.edge.aten.view_copy.default, - exir_ops.edge.aten.alias_copy.default, - exir_ops.edge.aten.clone.default, - exir_ops.edge.dim_order_ops._clone_dim_order.default, - exir_ops.edge.aten._to_copy.default, - exir_ops.edge.dim_order_ops._to_dim_order_copy.default, - exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, - exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, - exir_ops.edge.aten.abs.default, - exir_ops.edge.aten.clamp.default, - exir_ops.edge.aten.ceil.default, - exir_ops.edge.aten.floor.default, - exir_ops.edge.aten.neg.default, - exir_ops.edge.aten.relu.default, - exir_ops.edge.aten.round.default, - exir_ops.edge.aten.sigmoid.default, - exir_ops.edge.aten.silu.default, - exir_ops.edge.aten.sqrt.default, - exir_ops.edge.aten.tanh.default, - exir_ops.edge.aten.sign.default, - exir_ops.edge.aten.reciprocal.default, - exir_ops.edge.aten.gelu.default, - exir_ops.edge.aten.rsqrt.default, - exir_ops.edge.aten.exp.default, - exir_ops.edge.aten.log.default, -] - - -def merge_view_copy_chains(graph: torch.fx.Graph) -> tuple[torch.fx.Graph, bool]: - """ - Find chains of view_copy nodes and unary elementwise ops and set all - view_copy nodes to have the final shape. The views will then be removed - by the remove_noop_view_copy call. - - Only merges view_copy nodes that are not used by any other nodes. 
- """ - ops = exir_ops.edge - view_op = ops.aten.view_copy.default - modified = False - for node in graph.nodes: - if node.op == "call_function" and node.target == view_op: - # Find a chain of unary elementwise ops and save all view_copy nodes - end_node = node - view_ops = [node] - while ( - end_node.op == "call_function" - and end_node.target in UNARY_ELEMENTWISE_OPS - and len(end_node.users) == 1 - and list(end_node.users)[0].target in UNARY_ELEMENTWISE_OPS - ): - end_node = list(end_node.users)[0] - if end_node.target == view_op: - view_ops.append(end_node) - - # Set all view_copy nodes to have the final shape - if len(view_ops) > 1: - final_shape = view_ops[-1].args[1] - for node in view_ops: - new_args = (node.args[0], final_shape) - node.args = new_args - modified = True - - graph.eliminate_dead_code() - return graph, modified - - -def remove_noop_view_copy(graph: torch.fx.Graph) -> tuple[torch.fx.Graph, bool]: - """ - Remove view_copy nodes that are no-ops. - """ - ops = exir_ops.edge - view_op = ops.aten.view_copy.default - modified = False - for node in graph.nodes: - if node.op == "call_function" and node.target == view_op: - input_shape = list(node.args[0].meta["val"].shape) - target_shape = node.args[1] - if input_shape == target_shape: - node.replace_all_uses_with(node.args[0]) - modified = True - graph.eliminate_dead_code() - return graph, modified - - class FuseViewCopyTransform(ExportPass): _passes_required_after: Set[Type[ExportPass]] = set() + VIEW_OP = exir_ops.edge.aten.view_copy.default + + UNARY_ELEMENTWISE_OPS = [ + exir_ops.edge.aten.alias_copy.default, + exir_ops.edge.aten.clone.default, + exir_ops.edge.dim_order_ops._clone_dim_order.default, + exir_ops.edge.aten._to_copy.default, + exir_ops.edge.dim_order_ops._to_dim_order_copy.default, + exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default, + exir_ops.edge.aten.abs.default, + exir_ops.edge.aten.clamp.default, + 
def merge_view_copy_chains(
    self, graph: torch.fx.Graph
) -> tuple[torch.fx.Graph, bool]:
    """Give every view_copy in a chain the shape of the chain's last view_copy.

    Starting from each ``self.VIEW_OP`` node, the chain is followed through
    single-user nodes whose targets are ``self.VIEW_OP`` or one of
    ``self.UNARY_ELEMENTWISE_OPS``. When the chain holds more than one
    view_copy, each of them gets its shape argument replaced by the shape
    argument of the last one. The redundant views are then expected to be
    dropped by ``remove_noop_view_copy``.

    Only nodes with exactly one user are walked through, so a view_copy whose
    result is consumed elsewhere ends the chain.

    Args:
        graph: FX graph to rewrite in place.

    Returns:
        The same graph object and a flag that is ``True`` only when at least
        one view_copy actually received a different shape argument.
    """
    view_op = self.VIEW_OP
    chain_ops = (*self.UNARY_ELEMENTWISE_OPS, view_op)
    modified = False

    for node in graph.nodes:
        if node.op != "call_function" or node.target != view_op:
            continue

        # Walk forward while the chain stays linear and made of known ops,
        # remembering every view_copy seen on the way.
        end_node = node
        chain_views = [node]
        while (
            end_node.op == "call_function"
            and end_node.target in chain_ops
            and len(end_node.users) == 1
        ):
            user = next(iter(end_node.users))
            if user.target not in chain_ops:
                break
            end_node = user
            if end_node.target == view_op:
                chain_views.append(end_node)

        if len(chain_views) < 2:
            continue

        final_shape = chain_views[-1].args[1]
        for view_node in chain_views:
            # Only report a modification when an argument really changes, so
            # callers do not recompile/retrace an untouched graph.
            if view_node.args[1] != final_shape:
                view_node.args = (view_node.args[0], final_shape)
                modified = True

    if modified:
        graph.eliminate_dead_code()
    return graph, modified


def remove_noop_view_copy(
    self, graph: torch.fx.Graph
) -> tuple[torch.fx.Graph, bool]:
    """Bypass view_copy nodes whose output shape equals their input shape.

    Shapes are read from the ``meta["val"]`` entries of the view_copy node and
    of its first argument. Every user of a matching view_copy is redirected to
    that argument, and dead nodes are eliminated afterwards.

    Args:
        graph: FX graph to rewrite in place.

    Returns:
        The same graph object and a flag telling whether any view_copy was
        bypassed.
    """
    view_op = self.VIEW_OP
    modified = False

    for node in graph.nodes:
        if node.op != "call_function" or node.target != view_op:
            continue
        source = node.args[0]
        if list(source.meta["val"].shape) == list(node.meta["val"].shape):
            node.replace_all_uses_with(source)
            modified = True

    if modified:
        graph.eliminate_dead_code()
    return graph, modified