Skip to content

Commit 81db15b

Browse files
abeakkas
authored and facebook-github-bot committed
Generalize and simplify permute fusion. (#18254)
Summary: The old algorithm can't handle the case where intermediate permutes have multiple users. Updating to a more concise and general approach. Reviewed By: DrJessop Differential Revision: D97000790
1 parent b2f0a5a commit 81db15b

3 files changed

Lines changed: 97 additions & 65 deletions

File tree

backends/cadence/aot/compiler_utils.py

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -87,28 +87,6 @@ def broadcastable(shape_1: Sequence[int], shape_2: Sequence[int]) -> bool:
8787
)
8888

8989

90-
# Return a chain of nodes with target in op_targets
91-
def get_cascaded_ops(
92-
nodes: List[torch.fx.Node],
93-
# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
94-
op_targets: Iterable[Union[Callable[..., Any], str]],
95-
) -> Sequence[torch.fx.Node]:
96-
"""
97-
'nodes' contains a chain of ops with target in 'op_targets'. Extend that chain
98-
by one if nodes[-1] has a single user with its op target in 'op_targets'.
99-
"""
100-
cur = nodes[-1]
101-
users = list(cur.users.keys())
102-
# Assert that (a) there is only one user of cur, and (b) that user is
103-
# one of the op in op_targets.
104-
if len(users) == 1 and users[0].target in op_targets:
105-
nodes.append(users[0])
106-
# Recursively find the chain starting at the user
107-
return get_cascaded_ops(nodes, op_targets)
108-
109-
return nodes
110-
111-
11290
def get_transposed_dims(
11391
node: torch.fx.Node, dims: Optional[List[int]] = None
11492
) -> List[int]:

backends/cadence/aot/fuse_ops.py

Lines changed: 27 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import torch.fx
2323
from executorch.backends.cadence.aot.compiler_utils import (
2424
broadcastable,
25-
get_cascaded_ops,
2625
get_permuted_dims,
2726
get_scale,
2827
get_tensor_from_attr,
@@ -581,7 +580,8 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
581580
@register_cadence_pass(CadencePassAttribute(opt_level=1))
582581
class FuseCascadedTransposeOrPermuteOps(RemoveOrReplacePassInterface):
583582
"""
584-
Fuse a cascaded chain of transpose and permute ops
583+
Fuse a chain of transpose and permute ops into a single permute or a no-op.
584+
Handles branched and chained permutes.
585585
"""
586586

587587
transpose_or_permute_target = {
@@ -594,55 +594,39 @@ def targets(self) -> list[EdgeOpOverload]:
594594
return list(self.transpose_or_permute_target)
595595

596596
def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
597-
# Get the cascaded chain of transpose/permute ops starting at node
598-
cascaded_transpose_or_permute_ops = get_cascaded_ops(
599-
[node], self.transpose_or_permute_target
600-
)
601-
# The chain must have more than 1 node
602-
if len(cascaded_transpose_or_permute_ops) == 1:
597+
# Fuse with the parent node if it's also a permute or a transpose. Since the
598+
# pass interface traverses all ops in order the pass will properly fuse a chain
599+
# of permutes.
600+
parent_node = get_arg(node, "input", torch.fx.Node)
601+
if parent_node.target not in self.transpose_or_permute_target:
603602
return False
603+
input_of_parent = get_arg(parent_node, "input", torch.fx.Node)
604604

605-
# Get shape from node metadata
606-
val = node.meta.get("val")
607-
if val is None:
608-
return False
609-
out_shape = val.shape
610-
out_dims = len(out_shape)
611-
612-
# This is the trivial dimension order
613-
dims = list(range(out_dims))
614-
# Compute the effect of the chain on dims
615-
for tp in cascaded_transpose_or_permute_ops:
616-
dims = (
617-
get_transposed_dims(tp, dims)
618-
if tp.target == exir_ops.edge.aten.transpose_copy.int
619-
else get_permuted_dims(tp, dims)
620-
)
605+
# Compute combined effect of permutes.
606+
dims = list(range(node.meta["val"].ndim))
621607

622-
graph = node.graph
608+
if parent_node.target == exir_ops.edge.aten.transpose_copy.int:
609+
dims = get_transposed_dims(parent_node, dims)
610+
else:
611+
dims = get_permuted_dims(parent_node, dims)
623612

624-
# In case the permute chain cancelled each other, the final dims will
625-
# be the same as the initial order. In that case, the chain was nop.
626-
# Otherwise create a new permute op that encompasses the effect of the
627-
# chain.
628-
if dims == list(range(out_dims)):
629-
cascaded_transpose_or_permute_ops[-1].replace_all_uses_with(
630-
cast(torch.fx.Node, node.args[0])
631-
)
613+
if node.target == exir_ops.edge.aten.transpose_copy.int:
614+
dims = get_transposed_dims(node, dims)
632615
else:
633-
with graph.inserting_before(cascaded_transpose_or_permute_ops[-1]):
634-
new_permute = graph.call_function(
616+
dims = get_permuted_dims(node, dims)
617+
618+
# If combined effect is identity replace the node with input.
619+
if dims == sorted(dims):
620+
node.replace_all_uses_with(input_of_parent)
621+
else:
622+
with node.graph.inserting_before(node):
623+
new_permute = node.graph.call_function(
635624
exir_ops.edge.aten.permute_copy.default,
636-
args=(node.args[0], dims),
625+
args=(input_of_parent, dims),
637626
)
638-
new_permute.meta = cascaded_transpose_or_permute_ops[-1].meta
639-
cascaded_transpose_or_permute_ops[-1].replace_all_uses_with(new_permute)
640-
641-
# Now erase the chain (except the first node which will be handled by the interface)
642-
for tp in reversed(cascaded_transpose_or_permute_ops[1:]):
643-
graph.erase_node(tp)
627+
new_permute.meta = node.meta
628+
node.replace_all_uses_with(new_permute)
644629

645-
# Return True to indicate the first node in the chain should be removed
646630
return True
647631

648632

backends/cadence/aot/tests/test_fusion_ops_passes.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,76 @@ def test_permute_transpose_fusion(self) -> None:
289289
graph_copy, converted_graph, (x_input,), "FuseCascadedTransposeOrPermuteOps"
290290
)
291291

292+
def test_cascaded_permutes_multiple_users(self) -> None:
293+
# Test case where intermediate permute has multiple users.
294+
# x
295+
# |
296+
# permute1
297+
# / | \
298+
# permute2 permute3 permute4
299+
# | | |
300+
# out0 out1 permute5
301+
# |
302+
# out2
303+
304+
builder = GraphBuilder()
305+
x_input = torch.randn(2, 3, 8, 8, dtype=torch.float32)
306+
x = builder.placeholder("x", x_input)
307+
permute1 = builder.call_operator(
308+
op=exir_ops.edge.aten.permute_copy.default,
309+
args=(x, [0, 2, 3, 1]),
310+
)
311+
permute2 = builder.call_operator(
312+
op=exir_ops.edge.aten.permute_copy.default,
313+
args=(permute1, [0, 3, 1, 2]),
314+
)
315+
permute3 = builder.call_operator(
316+
op=exir_ops.edge.aten.permute_copy.default,
317+
args=(permute1, [0, 1, 3, 2]),
318+
)
319+
permute4 = builder.call_operator(
320+
op=exir_ops.edge.aten.permute_copy.default,
321+
args=(permute1, [3, 2, 1, 0]),
322+
)
323+
permute5 = builder.call_operator(
324+
op=exir_ops.edge.aten.permute_copy.default,
325+
args=(permute4, [1, 2, 3, 0]),
326+
)
327+
builder.output([permute2, permute3, permute5])
328+
original_graph = builder.get_graph_module()
329+
graph_copy = copy.deepcopy(original_graph)
330+
331+
p = FuseCascadedTransposeOrPermuteOps()
332+
result = p.call(original_graph)
333+
self.assertTrue(result.modified)
334+
converted_graph = result.graph_module
335+
336+
# permute2 becomes a no-op, permute3 and permute5 fused with preceding permutes
337+
# into new single permutes.
338+
output0, output1, output2 = converted_graph.graph.output_node().args[0]
339+
# out0: permute1 + permute2 = identity, so it connects to the graph input.
340+
graph_input = converted_graph.graph.find_nodes(op="placeholder")[0]
341+
self.assertIs(output0, graph_input)
342+
# out1: permute1 [0,2,3,1] + permute3 [0,1,3,2] fused to [0,2,1,3].
343+
self.assertEqual(
344+
output1.target, exir_ops.edge.aten.permute_copy.default
345+
)
346+
self.assertIs(output1.args[0], graph_input)
347+
self.assertEqual(output1.args[1], [0, 2, 1, 3])
348+
# out2: permute1 [0,2,3,1] + permute4 [3,2,1,0] + permute5 [1,2,3,0]
349+
# fused to [3,2,0,1].
350+
self.assertEqual(
351+
output2.target, exir_ops.edge.aten.permute_copy.default
352+
)
353+
self.assertIs(output2.args[0], graph_input)
354+
self.assertEqual(output2.args[1], [3, 2, 0, 1])
355+
validate_numerics(
356+
graph_copy,
357+
converted_graph,
358+
(x_input,),
359+
"FuseCascadedTransposeOrPermuteOps_multiple_users",
360+
)
361+
292362
def test_view_fusion(self) -> None:
293363
builder = GraphBuilder()
294364
x_input = torch.randn(8, 5, 3, dtype=torch.float32)

0 commit comments

Comments (0)