2222import torch .fx
2323from executorch .backends .cadence .aot .compiler_utils import (
2424 broadcastable ,
25- get_cascaded_ops ,
2625 get_permuted_dims ,
2726 get_scale ,
2827 get_tensor_from_attr ,
@@ -581,7 +580,7 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
581580@register_cadence_pass (CadencePassAttribute (opt_level = 1 ))
582581class FuseCascadedTransposeOrPermuteOps (RemoveOrReplacePassInterface ):
583582 """
584- Fuse a cascaded chain of transpose and permute ops
583+ Fuse a chain of transpose and permute ops into a single permute or a no-op.
585584 """
586585
587586 transpose_or_permute_target = {
@@ -594,55 +593,44 @@ def targets(self) -> list[EdgeOpOverload]:
594593 return list (self .transpose_or_permute_target )
595594
def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
    """
    Collapse the run of transpose/permute ops that ends at ``node``.

    Starting from ``node``, follow ``args[0]`` upwards for as long as the
    producer's target is in ``self.transpose_or_permute_target``. When at
    least two such ops are found, their combined dimension ordering is
    computed and all uses of ``node`` are redirected either to the input
    of the run (when the ops cancel out) or to one freshly created
    ``permute_copy`` applied to that input.

    Args:
        node: The last transpose/permute op of the candidate run.

    Returns:
        False when ``node`` has no "val" metadata or is not preceded by
        another transpose/permute op; True once the uses of ``node`` have
        been rewired (presumably the pass interface then removes ``node``
        itself -- confirm against RemoveOrReplacePassInterface).
    """
    # The rank is read from the node's "val" metadata; without it there is
    # nothing to compose.
    if "val" not in node.meta:
        return False
    num_dims = len(node.meta["val"].shape)

    # Gather the run bottom-up: `node` first, its producers afterwards.
    run = [node]
    source = node.args[0]
    while True:
        if not isinstance(source, torch.fx.Node):
            break
        if source.target not in self.transpose_or_permute_target:
            break
        run.append(source)
        source = source.args[0]

    # A single op has nothing to be fused with.
    if len(run) < 2:
        return False

    # `run` was collected consumer-to-producer, so apply it back to front
    # to obtain the overall ordering in execution order.
    identity = list(range(num_dims))
    order = list(identity)
    for op in run[::-1]:
        if op.target == exir_ops.edge.aten.transpose_copy.int:
            order = get_transposed_dims(op, order)
            continue
        assert op.target == exir_ops.edge.aten.permute_copy.default
        order = get_permuted_dims(op, order)

    if order == identity:
        # The ops cancel each other out: consumers can read the input directly.
        node.replace_all_uses_with(source)
        return True

    # Otherwise a single permute reproduces the effect of the whole run.
    graph = node.graph
    with graph.inserting_before(node):
        fused = graph.call_function(
            exir_ops.edge.aten.permute_copy.default,
            args=(source, order),
        )
    fused.meta = node.meta
    node.replace_all_uses_with(fused)
    return True
647635
648636
0 commit comments