Skip to content

Commit 5e1f08b

Browse files
abeakkas authored and facebook-github-bot committed
Generalize and simplify permute fusion.
Summary: The old algorithm can't handle the case where intermediate permutes have multiple users. Updating to a more concise and general approach. Differential Revision: D97000790
1 parent ed57040 commit 5e1f08b

3 files changed

Lines changed: 73 additions & 64 deletions

File tree

backends/cadence/aot/compiler_utils.py

Lines changed: 0 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -87,28 +87,6 @@ def broadcastable(shape_1: Sequence[int], shape_2: Sequence[int]) -> bool:
8787
)
8888

8989

90-
# Return a chain of nodes with target in op_targets
91-
def get_cascaded_ops(
92-
nodes: List[torch.fx.Node],
93-
# pyre-fixme[2]: Parameter annotation cannot contain `Any`.
94-
op_targets: Iterable[Union[Callable[..., Any], str]],
95-
) -> Sequence[torch.fx.Node]:
96-
"""
97-
'nodes' contains a chain of ops with target in 'op_targets'. Extend that chain
98-
by one if nodes[-1] has a single user with its op target in 'op_targets'.
99-
"""
100-
cur = nodes[-1]
101-
users = list(cur.users.keys())
102-
# Assert that (a) there is only one user of cur, and (b) that user is
103-
# one of the op in op_targets.
104-
if len(users) == 1 and users[0].target in op_targets:
105-
nodes.append(users[0])
106-
# Recursively find the chain starting at the user
107-
return get_cascaded_ops(nodes, op_targets)
108-
109-
return nodes
110-
111-
11290
def get_transposed_dims(
11391
node: torch.fx.Node, dims: Optional[List[int]] = None
11492
) -> List[int]:

backends/cadence/aot/fuse_ops.py

Lines changed: 30 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
import torch.fx
2323
from executorch.backends.cadence.aot.compiler_utils import (
2424
broadcastable,
25-
get_cascaded_ops,
2625
get_permuted_dims,
2726
get_scale,
2827
get_tensor_from_attr,
@@ -581,7 +580,7 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
581580
@register_cadence_pass(CadencePassAttribute(opt_level=1))
582581
class FuseCascadedTransposeOrPermuteOps(RemoveOrReplacePassInterface):
583582
"""
584-
Fuse a cascaded chain of transpose and permute ops
583+
Fuse a chain of transpose and permute ops into a single permute or a no-op.
585584
"""
586585

587586
transpose_or_permute_target = {
@@ -594,55 +593,44 @@ def targets(self) -> list[EdgeOpOverload]:
594593
return list(self.transpose_or_permute_target)
595594

596595
def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
597-
# Get the cascaded chain of transpose/permute ops starting at node
598-
cascaded_transpose_or_permute_ops = get_cascaded_ops(
599-
[node], self.transpose_or_permute_target
600-
)
601-
# The chain must have more than 1 node
602-
if len(cascaded_transpose_or_permute_ops) == 1:
596+
if "val" not in node.meta:
603597
return False
598+
rank = len(node.meta["val"].shape)
599+
600+
# Walk up the graph collecting consecutive permute/transpose ops.
601+
chain = [node]
602+
input_node = node.args[0]
603+
while (
604+
isinstance(input_node, torch.fx.Node)
605+
and input_node.target in self.transpose_or_permute_target
606+
):
607+
chain.append(input_node)
608+
input_node = input_node.args[0]
604609

605-
# Get shape from node metadata
606-
val = node.meta.get("val")
607-
if val is None:
610+
if len(chain) < 2:
608611
return False
609-
out_shape = val.shape
610-
out_dims = len(out_shape)
611-
612-
# This is the trivial dimension order
613-
dims = list(range(out_dims))
614-
# Compute the effect of the chain on dims
615-
for tp in cascaded_transpose_or_permute_ops:
616-
dims = (
617-
get_transposed_dims(tp, dims)
618-
if tp.target == exir_ops.edge.aten.transpose_copy.int
619-
else get_permuted_dims(tp, dims)
620-
)
621612

622-
graph = node.graph
613+
# Compute combined effect of permutes (chain is populated in reverse order).
614+
dims = list(range(rank))
615+
for op in reversed(chain):
616+
if op.target == exir_ops.edge.aten.transpose_copy.int:
617+
dims = get_transposed_dims(op, dims)
618+
else:
619+
assert op.target == exir_ops.edge.aten.permute_copy.default
620+
dims = get_permuted_dims(op, dims)
623621

624-
# In case the permute chain cancelled each other, the final dims will
625-
# be the same as the initial order. In that case, the chain was nop.
626-
# Otherwise create a new permute op that encompasses the effect of the
627-
# chain.
628-
if dims == list(range(out_dims)):
629-
cascaded_transpose_or_permute_ops[-1].replace_all_uses_with(
630-
cast(torch.fx.Node, node.args[0])
631-
)
622+
# If combined effect is identity replace the node with input.
623+
if dims == list(range(rank)):
624+
node.replace_all_uses_with(input_node)
632625
else:
633-
with graph.inserting_before(cascaded_transpose_or_permute_ops[-1]):
634-
new_permute = graph.call_function(
626+
with node.graph.inserting_before(node):
627+
new_permute = node.graph.call_function(
635628
exir_ops.edge.aten.permute_copy.default,
636-
args=(node.args[0], dims),
629+
args=(input_node, dims),
637630
)
638-
new_permute.meta = cascaded_transpose_or_permute_ops[-1].meta
639-
cascaded_transpose_or_permute_ops[-1].replace_all_uses_with(new_permute)
640-
641-
# Now erase the chain (except the first node which will be handled by the interface)
642-
for tp in reversed(cascaded_transpose_or_permute_ops[1:]):
643-
graph.erase_node(tp)
631+
new_permute.meta = node.meta
632+
node.replace_all_uses_with(new_permute)
644633

645-
# Return True to indicate the first node in the chain should be removed
646634
return True
647635

648636

backends/cadence/aot/tests/test_fusion_ops_passes.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,49 @@ def test_permute_transpose_fusion(self) -> None:
289289
graph_copy, converted_graph, (x_input,), "FuseCascadedTransposeOrPermuteOps"
290290
)
291291

292+
def test_cascaded_permutes_multiple_users(self) -> None:
293+
# Test case where intermediate permute has multiple users.
294+
# x
295+
# |
296+
# permute1
297+
# / \
298+
# permute2 permute3
299+
# | |
300+
# out0 out1
301+
builder = GraphBuilder()
302+
x_input = torch.randn(1, 3, 8, 8, dtype=torch.float32)
303+
x = builder.placeholder("x", x_input)
304+
permute1 = builder.call_operator(
305+
op=exir_ops.edge.aten.permute_copy.default,
306+
args=(x, [0, 2, 3, 1]),
307+
)
308+
permute2 = builder.call_operator(
309+
op=exir_ops.edge.aten.permute_copy.default,
310+
args=(permute1, [0, 3, 1, 2]),
311+
)
312+
permute3 = builder.call_operator(
313+
op=exir_ops.edge.aten.permute_copy.default,
314+
args=(permute1, [0, 1, 3, 2]),
315+
)
316+
builder.output([permute2, permute3])
317+
original_graph = builder.get_graph_module()
318+
graph_copy = copy.deepcopy(original_graph)
319+
320+
p = FuseCascadedTransposeOrPermuteOps()
321+
result = p.call(original_graph)
322+
self.assertTrue(result.modified)
323+
converted_graph = result.graph_module
324+
converted_graph.graph.eliminate_dead_code()
325+
326+
# permute2 becomes a no-op, permute3 is fused with permute1.
327+
self.assertEqual(count_node(converted_graph, exir_ops.edge.aten.permute_copy.default), 1)
328+
validate_numerics(
329+
graph_copy,
330+
converted_graph,
331+
(x_input,),
332+
"FuseCascadedTransposeOrPermuteOps_multiple_users",
333+
)
334+
292335
def test_view_fusion(self) -> None:
293336
builder = GraphBuilder()
294337
x_input = torch.randn(8, 5, 3, dtype=torch.float32)

0 commit comments

Comments
 (0)