Skip to content

Commit 7fdb60e

Browse files
mcremon-metameta-codesync[bot]
authored andcommitted
Update permute removal pass to handle binary operations, and cleanup better (#18256)
Summary: Pull Request resolved: #18256 As titled. It is currently not cleaning up as much as it should, and the pass is only capable of handling single input cases. Result: from 9 to 1 (minimum by construction) permutes on Wake Gesture. Differential Revision: D96940254 Reviewed By: abeakkas
1 parent c5dfd18 commit 7fdb60e

File tree

4 files changed

+13
-2
lines changed

4 files changed

+13
-2
lines changed

backends/cadence/aot/BUCK

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,6 +300,7 @@ fbcode_target(_kind = runtime.python_library,
300300
],
301301
typing = True,
302302
deps = [
303+
":fuse_ops",
303304
":ops_registrations",
304305
"//caffe2:torch",
305306
"//executorch/backends/cadence/aot:pass_utils",

backends/cadence/aot/fuse_ops.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1170,7 +1170,10 @@ def can_fuse_for_chain(
11701170
return False
11711171

11721172
# checking that permut2(permut1(identity)) == identity, modulo unitary dimensions
1173-
input_shape = cast(torch.fx.Node, producer.args[0]).meta["val"].shape
1173+
producer_input = cast(torch.fx.Node, producer.args[0])
1174+
if "val" not in producer_input.meta:
1175+
return False
1176+
input_shape = producer_input.meta["val"].shape
11741177
ident_dims = list(range(len(input_shape)))
11751178
# this mapping helps to handle both transpose and permutations
11761179
f: dict[Any, Callable] = {

backends/cadence/aot/passes.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from executorch.backends.cadence.aot.remove_ops import (
2626
CadenceRemoveNops,
2727
RemoveNopSliceOrViewOpPass,
28+
RemovePermutesAroundElementwiseOps,
2829
RemoveRedundantOps,
2930
)
3031
from executorch.backends.cadence.aot.reorder_ops import CadenceReorderOpsInGraph
@@ -89,6 +90,7 @@ def get_passes_in_default_order() -> list[Type[ExportPass]]:
8990
CadenceSimplifyOpsInGraph.passes,
9091
FinalizePipeline,
9192
FuseFullThenReshapePass,
93+
RemovePermutesAroundElementwiseOps,
9294
FuseTransposeOrPermuteOpPairsPass,
9395
RemoveNopSliceOrViewOpPass,
9496
CompileTimeTypeDispatchPass,

backends/cadence/aot/remove_ops.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,14 +14,15 @@
1414

1515
import torch
1616
import torch.fx
17+
18+
from executorch.backends.cadence.aot.fuse_ops import FuseTransposeOrPermuteOpPairsPass
1719
from executorch.backends.cadence.aot.pass_utils import (
1820
CadencePassAttribute,
1921
get_arg,
2022
register_cadence_pass,
2123
RemoveOrReplacePassInterface,
2224
set_arg,
2325
)
24-
2526
from executorch.backends.cadence.aot.simplify_ops import SimplifySliceOpPass
2627
from executorch.backends.cadence.aot.utils import get_edge_overload_packet
2728
from executorch.backends.transforms.remove_clone_ops import RemoveCloneOpsTransform
@@ -412,6 +413,9 @@ class Subgraph:
412413
exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default,
413414
exir_ops.edge.cadence.quantize_per_tensor.default,
414415
exir_ops.edge.cadence.dequantize_per_tensor.default,
416+
exir_ops.edge.cadence.quantized_relu.per_tensor,
417+
exir_ops.edge.cadence.requantize.per_tensor,
418+
exir_ops.edge.cadence.quantized_add.per_tensor,
415419
# Ops that require special handling.
416420
exir_ops.edge.aten.cat.default,
417421
exir_ops.edge.aten.mean.dim,
@@ -804,6 +808,7 @@ class CommonRemovePasses:
804808
RemoveToOpsPass,
805809
RemoveZeroSizedCatArgsPass,
806810
RemovePermutesAroundElementwiseOps,
811+
FuseTransposeOrPermuteOpPairsPass,
807812
RemoveSqueezeViewBeforeElementwiseOps,
808813
RemoveCatFromSliceCopyPass,
809814
RemoveCloneOpsTransformImported,

0 commit comments

Comments
 (0)