From 44adb391f517fbcc33699921e09a1606b971ba03 Mon Sep 17 00:00:00 2001 From: Abdurrahman Akkas Date: Thu, 19 Mar 2026 11:00:38 -0700 Subject: [PATCH] Generalize and simplify permute fusion. (#18254) Summary: The old algorithm can't handle the case where intermediate permutes have multiple users. Updating to a more concise and general approach. Reviewed By: DrJessop Differential Revision: D97000790 --- backends/cadence/aot/compiler_utils.py | 22 ------ backends/cadence/aot/fuse_ops.py | 70 +++++++------------ .../aot/tests/test_fusion_ops_passes.py | 66 +++++++++++++++++ 3 files changed, 93 insertions(+), 65 deletions(-) diff --git a/backends/cadence/aot/compiler_utils.py b/backends/cadence/aot/compiler_utils.py index b55d388691f..eff3f49abbf 100644 --- a/backends/cadence/aot/compiler_utils.py +++ b/backends/cadence/aot/compiler_utils.py @@ -87,28 +87,6 @@ def broadcastable(shape_1: Sequence[int], shape_2: Sequence[int]) -> bool: ) -# Return a chain of nodes with target in op_targets -def get_cascaded_ops( - nodes: List[torch.fx.Node], - # pyre-fixme[2]: Parameter annotation cannot contain `Any`. - op_targets: Iterable[Union[Callable[..., Any], str]], -) -> Sequence[torch.fx.Node]: - """ - 'nodes' contains a chain of ops with target in 'op_targets'. Extend that chain - by one if nodes[-1] has a single user with its op target in 'op_targets'. - """ - cur = nodes[-1] - users = list(cur.users.keys()) - # Assert that (a) there is only one user of cur, and (b) that user is - # one of the op in op_targets. 
- if len(users) == 1 and users[0].target in op_targets: - nodes.append(users[0]) - # Recursively find the chain starting at the user - return get_cascaded_ops(nodes, op_targets) - - return nodes - - def get_transposed_dims( node: torch.fx.Node, dims: Optional[List[int]] = None ) -> List[int]: diff --git a/backends/cadence/aot/fuse_ops.py b/backends/cadence/aot/fuse_ops.py index e71803c03bb..023a6f5760a 100644 --- a/backends/cadence/aot/fuse_ops.py +++ b/backends/cadence/aot/fuse_ops.py @@ -22,7 +22,6 @@ import torch.fx from executorch.backends.cadence.aot.compiler_utils import ( broadcastable, - get_cascaded_ops, get_permuted_dims, get_scale, get_tensor_from_attr, @@ -581,7 +580,8 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: @register_cadence_pass(CadencePassAttribute(opt_level=1)) class FuseCascadedTransposeOrPermuteOps(RemoveOrReplacePassInterface): """ - Fuse a cascaded chain of transpose and permute ops + Fuse a chain of transpose and permute ops into a single permute or a no-op. + Handles branches and chains of permutes. """ transpose_or_permute_target = { @@ -594,55 +594,39 @@ def targets(self) -> list[EdgeOpOverload]: return list(self.transpose_or_permute_target) def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool: - # Get the cascaded chain of transpose/permute ops starting at node - cascaded_transpose_or_permute_ops = get_cascaded_ops( - [node], self.transpose_or_permute_target - ) - # The chain must have more than 1 node - if len(cascaded_transpose_or_permute_ops) == 1: + # Fuse with the parent node if it's also a permute or a transpose. Since the + # pass interface traverses all ops in order, the pass will properly fuse a chain + # of permutes. 
+ parent_node = get_arg(node, "input", torch.fx.Node) + if parent_node.target not in self.transpose_or_permute_target: return False + input_of_parent = get_arg(parent_node, "input", torch.fx.Node) - # Get shape from node metadata - val = node.meta.get("val") - if val is None: - return False - out_shape = val.shape - out_dims = len(out_shape) - - # This is the trivial dimension order - dims = list(range(out_dims)) - # Compute the effect of the chain on dims - for tp in cascaded_transpose_or_permute_ops: - dims = ( - get_transposed_dims(tp, dims) - if tp.target == exir_ops.edge.aten.transpose_copy.int - else get_permuted_dims(tp, dims) - ) + # Compute combined effect of permutes. + dims = list(range(node.meta["val"].ndim)) - graph = node.graph + if parent_node.target == exir_ops.edge.aten.transpose_copy.int: + dims = get_transposed_dims(parent_node, dims) + else: + dims = get_permuted_dims(parent_node, dims) - # In case the permute chain cancelled each other, the final dims will - # be the same as the initial order. In that case, the chain was nop. - # Otherwise create a new permute op that encompasses the effect of the - # chain. - if dims == list(range(out_dims)): - cascaded_transpose_or_permute_ops[-1].replace_all_uses_with( - cast(torch.fx.Node, node.args[0]) - ) + if node.target == exir_ops.edge.aten.transpose_copy.int: + dims = get_transposed_dims(node, dims) + else: + dims = get_permuted_dims(node, dims) + + # If the combined effect is the identity, replace the node with its input. 
+ if dims == sorted(dims): + node.replace_all_uses_with(input_of_parent) + else: + with node.graph.inserting_before(node): + new_permute = node.graph.call_function( exir_ops.edge.aten.permute_copy.default, - args=(node.args[0], dims), + args=(input_of_parent, dims), ) - new_permute.meta = cascaded_transpose_or_permute_ops[-1].meta - cascaded_transpose_or_permute_ops[-1].replace_all_uses_with(new_permute) - - # Now erase the chain (except the first node which will be handled by the interface) - for tp in reversed(cascaded_transpose_or_permute_ops[1:]): - graph.erase_node(tp) + new_permute.meta = node.meta + node.replace_all_uses_with(new_permute) - # Return True to indicate the first node in the chain should be removed return True diff --git a/backends/cadence/aot/tests/test_fusion_ops_passes.py b/backends/cadence/aot/tests/test_fusion_ops_passes.py index cc7e36b4695..b1e73dce94c 100644 --- a/backends/cadence/aot/tests/test_fusion_ops_passes.py +++ b/backends/cadence/aot/tests/test_fusion_ops_passes.py @@ -289,6 +289,72 @@ def test_permute_transpose_fusion(self) -> None: graph_copy, converted_graph, (x_input,), "FuseCascadedTransposeOrPermuteOps" ) + def test_cascaded_permutes_multiple_users(self) -> None: + # Test case where intermediate permute has multiple users. 
+ # x + # | + # permute1 + # / | \ + # permute2 permute3 permute4 + # | | | + # out0 out1 permute5 + # | + # out2 + + builder = GraphBuilder() + x_input = torch.randn(2, 3, 8, 8, dtype=torch.float32) + x = builder.placeholder("x", x_input) + permute1 = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, + args=(x, [0, 2, 3, 1]), + ) + permute2 = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, + args=(permute1, [0, 3, 1, 2]), + ) + permute3 = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, + args=(permute1, [0, 1, 3, 2]), + ) + permute4 = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, + args=(permute1, [3, 2, 1, 0]), + ) + permute5 = builder.call_operator( + op=exir_ops.edge.aten.permute_copy.default, + args=(permute4, [1, 2, 3, 0]), + ) + builder.output([permute2, permute3, permute5]) + original_graph = builder.get_graph_module() + graph_copy = copy.deepcopy(original_graph) + + p = FuseCascadedTransposeOrPermuteOps() + result = p.call(original_graph) + self.assertTrue(result.modified) + converted_graph = result.graph_module + + # permute2 becomes a no-op, permute3 and permute5 fused with preceding permutes + # into new single permutes. + output0, output1, output2 = converted_graph.graph.output_node().args[0] + # out0: permute1 + permute2 = identity, so it connects to the graph input. + graph_input = converted_graph.graph.find_nodes(op="placeholder")[0] + self.assertIs(output0, graph_input) + # out1: permute1 [0,2,3,1] + permute3 [0,1,3,2] fused to [0,2,1,3]. + self.assertEqual(output1.target, exir_ops.edge.aten.permute_copy.default) + self.assertIs(output1.args[0], graph_input) + self.assertEqual(output1.args[1], [0, 2, 1, 3]) + # out2: permute1 [0,2,3,1] + permute4 [3,2,1,0] + permute5 [1,2,3,0] + # fused to [3,2,0,1]. 
+ self.assertEqual(output2.target, exir_ops.edge.aten.permute_copy.default) + self.assertIs(output2.args[0], graph_input) + self.assertEqual(output2.args[1], [3, 2, 0, 1]) + validate_numerics( + graph_copy, + converted_graph, + (x_input,), + "FuseCascadedTransposeOrPermuteOps_multiple_users", + ) + def test_view_fusion(self) -> None: builder = GraphBuilder() x_input = torch.randn(8, 5, 3, dtype=torch.float32)