Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 0 additions & 22 deletions backends/cadence/aot/compiler_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,28 +87,6 @@ def broadcastable(shape_1: Sequence[int], shape_2: Sequence[int]) -> bool:
)


# Walk forward through the graph, collecting a chain of ops whose targets
# are in op_targets.
def get_cascaded_ops(
    nodes: List[torch.fx.Node],
    # pyre-fixme[2]: Parameter annotation cannot contain `Any`.
    op_targets: Iterable[Union[Callable[..., Any], str]],
) -> Sequence[torch.fx.Node]:
    """Extend the chain in `nodes` as far as possible and return it.

    `nodes` must end with the current tail of a chain of ops whose targets
    are in `op_targets`. The chain grows by one node whenever the tail has
    exactly one user and that user's target is in `op_targets`; it stops at
    the first fan-out or foreign op. The input list is mutated in place and
    also returned.
    """
    while True:
        tail_users = list(nodes[-1].users.keys())
        # Extend only through a single-user link into another matching op.
        if len(tail_users) != 1 or tail_users[0].target not in op_targets:
            return nodes
        nodes.append(tail_users[0])


def get_transposed_dims(
node: torch.fx.Node, dims: Optional[List[int]] = None
) -> List[int]:
Expand Down
70 changes: 27 additions & 43 deletions backends/cadence/aot/fuse_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import torch.fx
from executorch.backends.cadence.aot.compiler_utils import (
broadcastable,
get_cascaded_ops,
get_permuted_dims,
get_scale,
get_tensor_from_attr,
Expand Down Expand Up @@ -581,7 +580,8 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
@register_cadence_pass(CadencePassAttribute(opt_level=1))
class FuseCascadedTransposeOrPermuteOps(RemoveOrReplacePassInterface):
"""
Fuse a cascaded chain of transpose and permute ops
Fuse a chain of transpose and permute ops into a single permute or a no-op.
    Handles branches and chained permutes.
"""

transpose_or_permute_target = {
Expand All @@ -594,55 +594,39 @@ def targets(self) -> list[EdgeOpOverload]:
return list(self.transpose_or_permute_target)

def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
    """Fuse `node` (a transpose or permute) with its transpose/permute parent.

    Only a single parent/child pair is fused per call. Because the pass
    interface visits ops in graph order, repeated application collapses
    arbitrarily long chains, including chains with branches (each branch
    fuses with the shared parent independently).

    Returns True if `node` was replaced (the interface then erases it),
    False if the parent is not a transpose/permute op.
    """
    # Only fuse when the producer is also a permute or a transpose.
    parent_node = get_arg(node, "input", torch.fx.Node)
    if parent_node.target not in self.transpose_or_permute_target:
        return False
    input_of_parent = get_arg(parent_node, "input", torch.fx.Node)

    # Compute the combined effect of parent + node on the identity dim order.
    dims = list(range(node.meta["val"].ndim))

    if parent_node.target == exir_ops.edge.aten.transpose_copy.int:
        dims = get_transposed_dims(parent_node, dims)
    else:
        dims = get_permuted_dims(parent_node, dims)

    if node.target == exir_ops.edge.aten.transpose_copy.int:
        dims = get_transposed_dims(node, dims)
    else:
        dims = get_permuted_dims(node, dims)

    # dims is a permutation of range(ndim); if it is sorted, the pair is a
    # no-op, so bypass both ops entirely.
    if dims == sorted(dims):
        node.replace_all_uses_with(input_of_parent)
    else:
        # Replace the pair with one permute that applies the combined dims,
        # reading directly from the parent's input.
        with node.graph.inserting_before(node):
            new_permute = node.graph.call_function(
                exir_ops.edge.aten.permute_copy.default,
                args=(input_of_parent, dims),
            )
            new_permute.meta = node.meta
            node.replace_all_uses_with(new_permute)

    # Returning True tells the interface to erase `node`; the parent becomes
    # dead once it loses its users and is cleaned up by dead-code elimination.
    return True


Expand Down
66 changes: 66 additions & 0 deletions backends/cadence/aot/tests/test_fusion_ops_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,72 @@ def test_permute_transpose_fusion(self) -> None:
graph_copy, converted_graph, (x_input,), "FuseCascadedTransposeOrPermuteOps"
)

def test_cascaded_permutes_multiple_users(self) -> None:
    """Fusion when an intermediate permute feeds several consumers.

    Graph under test:
        x
        |
        permute1
        /   |    \
    permute2 permute3 permute4
        |    |        |
      out0  out1    permute5
                      |
                     out2
    """
    permute_op = exir_ops.edge.aten.permute_copy.default
    builder = GraphBuilder()
    x_input = torch.randn(2, 3, 8, 8, dtype=torch.float32)
    x = builder.placeholder("x", x_input)
    root = builder.call_operator(op=permute_op, args=(x, [0, 2, 3, 1]))
    branch_a = builder.call_operator(op=permute_op, args=(root, [0, 3, 1, 2]))
    branch_b = builder.call_operator(op=permute_op, args=(root, [0, 1, 3, 2]))
    branch_c = builder.call_operator(op=permute_op, args=(root, [3, 2, 1, 0]))
    tail = builder.call_operator(op=permute_op, args=(branch_c, [1, 2, 3, 0]))
    builder.output([branch_a, branch_b, tail])
    original_graph = builder.get_graph_module()
    graph_copy = copy.deepcopy(original_graph)

    result = FuseCascadedTransposeOrPermuteOps().call(original_graph)
    self.assertTrue(result.modified)
    converted_graph = result.graph_module

    output0, output1, output2 = converted_graph.graph.output_node().args[0]
    graph_input = converted_graph.graph.find_nodes(op="placeholder")[0]

    # Branch a: [0,2,3,1] then [0,3,1,2] cancel out, so the output is wired
    # straight to the graph input.
    self.assertIs(output0, graph_input)

    # Branch b: [0,2,3,1] then [0,1,3,2] fuse into a single [0,2,1,3].
    self.assertEqual(output1.target, exir_ops.edge.aten.permute_copy.default)
    self.assertIs(output1.args[0], graph_input)
    self.assertEqual(output1.args[1], [0, 2, 1, 3])

    # Branch c: [0,2,3,1], [3,2,1,0], [1,2,3,0] fuse into a single [3,2,0,1].
    self.assertEqual(output2.target, exir_ops.edge.aten.permute_copy.default)
    self.assertIs(output2.args[0], graph_input)
    self.assertEqual(output2.args[1], [3, 2, 0, 1])

    validate_numerics(
        graph_copy,
        converted_graph,
        (x_input,),
        "FuseCascadedTransposeOrPermuteOps_multiple_users",
    )

def test_view_fusion(self) -> None:
builder = GraphBuilder()
x_input = torch.randn(8, 5, 3, dtype=torch.float32)
Expand Down
Loading