2222import torch .fx
2323from executorch .backends .cadence .aot .compiler_utils import (
2424 broadcastable ,
25- get_cascaded_ops ,
2625 get_permuted_dims ,
2726 get_scale ,
2827 get_tensor_from_attr ,
@@ -581,7 +580,8 @@ def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
581580@register_cadence_pass (CadencePassAttribute (opt_level = 1 ))
582581class FuseCascadedTransposeOrPermuteOps (RemoveOrReplacePassInterface ):
583582 """
584- Fuse a cascaded chain of transpose and permute ops
583+ Fuse a chain of transpose and permute ops into a single permute or a no-op.
584+ Handles branches and chains permutes.
585585 """
586586
587587 transpose_or_permute_target = {
@@ -594,55 +594,39 @@ def targets(self) -> list[EdgeOpOverload]:
594594 return list (self .transpose_or_permute_target )
595595
def maybe_remove_or_replace(self, node: torch.fx.Node) -> bool:
    """
    Fuse `node` with its producer when both are transpose/permute ops.

    The dimension orders of the parent op and of `node` are composed. If the
    composition is the identity order, all uses of `node` are redirected to
    the parent's input; otherwise they are redirected to a single new
    permute_copy applied to the parent's input.

    Only the immediate parent is inspected. Per the original author's note,
    the pass interface visits ops in order, so longer chains are collapsed
    one pair at a time.

    Args:
        node: a transpose/permute node handed to this pass.

    Returns:
        False when the graph is left untouched (the parent is not a
        transpose/permute op, or `node` carries no "val" metadata);
        True after the uses of `node` have been redirected.
    """
    parent_node = get_arg(node, "input", torch.fx.Node)
    if parent_node.target not in self.transpose_or_permute_target:
        return False
    input_of_parent = get_arg(parent_node, "input", torch.fx.Node)

    # The rank comes from the node's "val" metadata. Without it the combined
    # permutation cannot be computed, so skip the node instead of raising
    # KeyError.
    val = node.meta.get("val")
    if val is None:
        return False

    # Start from the trivial order and apply the parent first, then the node.
    dims = list(range(val.ndim))
    for op in (parent_node, node):
        if op.target == exir_ops.edge.aten.transpose_copy.int:
            dims = get_transposed_dims(op, dims)
        else:
            dims = get_permuted_dims(op, dims)

    if dims == sorted(dims):
        # The two ops cancel each other: consumers can read the parent's
        # input directly.
        node.replace_all_uses_with(input_of_parent)
    else:
        with node.graph.inserting_before(node):
            new_permute = node.graph.call_function(
                exir_ops.edge.aten.permute_copy.default,
                args=(input_of_parent, dims),
            )
        # The replacement reuses the metadata dict object of the node it
        # replaces (same object, not a copy).
        new_permute.meta = node.meta
        node.replace_all_uses_with(new_permute)

    return True
647631
648632
0 commit comments