5151 ShapeType ,
5252 Stack ,
5353)
54+ from pytato .diagnostic import CannotBeLoweredToIndexLambda
5455from pytato .function import Call , FunctionDefinition , NamedCallResult
5556from pytato .scalar_expr import (
5657 FlopCounter as ScalarFlopCounter ,
6667 map_and_copy ,
6768)
6869from pytato .transform .lower_to_index_lambda import to_index_lambda
69- from pytato .utils import is_materializable
70+ from pytato .utils import has_taggable_materialization , is_materialized
7071
7172
7273if TYPE_CHECKING :
@@ -758,18 +759,6 @@ def update_for_Array(self, key_hash: Any, key: Any) -> None:
758759
759760# {{{ flop counting
760761
761- def _is_materialized (expr : ArrayOrNames | FunctionDefinition ) -> bool :
762- return (
763- is_materializable (expr )
764- and bool (expr .tags_of_type (ImplStored )))
765-
766-
767- def _is_unmaterialized (expr : ArrayOrNames | FunctionDefinition ) -> bool :
768- return (
769- is_materializable (expr )
770- and not bool (expr .tags_of_type (ImplStored )))
771-
772-
773762@dataclass
774763class UndefinedOpFlopCountError (ValueError ):
775764 op_name : str
@@ -787,7 +776,10 @@ def combine(self, *args: int) -> int:
787776 return sum (args )
788777
789778 def _get_own_flop_count (self , expr : Array ) -> int :
790- nflops = self .scalar_flop_counter (to_index_lambda (expr ).expr )
779+ try :
780+ nflops = self .scalar_flop_counter (to_index_lambda (expr ).expr )
781+ except CannotBeLoweredToIndexLambda :
782+ nflops = 0
791783 if not isinstance (nflops , int ):
792784 from pytato .scalar_expr import InputGatherer as ScalarInputGatherer
793785 var_names : set [str ] = set (ScalarInputGatherer ()(nflops ))
@@ -805,8 +797,7 @@ def rec(self, expr: ArrayOrNames) -> int:
805797 return self ._cache_retrieve (inputs )
806798 except KeyError :
807799 result : int
808- if _is_unmaterialized (expr ):
809- assert isinstance (expr , Array )
800+ if isinstance (expr , Array ) and not is_materialized (expr ):
810801 result = (
811802 self ._get_own_flop_count (expr )
812803 # Intentionally going to Mapper instead of super() to avoid
@@ -866,14 +857,16 @@ def map_call(self, expr: Call) -> None:
866857
867858 @override
868859 def post_visit (self , expr : ArrayOrNames | FunctionDefinition ) -> None :
869- if not _is_materialized (expr ):
860+ if not is_materialized (expr ):
870861 return
871862 assert isinstance (expr , Array )
872- unmaterialized_expr = expr .without_tags (ImplStored ())
873- self ._per_entry_flop_counter (unmaterialized_expr )
874- self .materialized_node_to_nflops [expr ] = (
875- product (expr .shape )
876- * self ._per_entry_flop_counter .node_to_nflops [unmaterialized_expr ])
863+ if has_taggable_materialization (expr ):
864+ unmaterialized_expr = expr .without_tags (ImplStored ())
865+ self .materialized_node_to_nflops [expr ] = (
866+ product (expr .shape )
867+ * self ._per_entry_flop_counter (unmaterialized_expr ))
868+ else :
869+ self .materialized_node_to_nflops [expr ] = 0
877870
878871
879872class _UnmaterializedSubexpressionUseCounter (CombineMapper [dict [Array , int ], Never ]):
@@ -892,8 +885,7 @@ def rec(self, expr: ArrayOrNames) -> dict[Array, int]:
892885 return self ._cache_retrieve (inputs )
893886 except KeyError :
894887 result : dict [Array , int ]
895- if _is_unmaterialized (expr ):
896- assert isinstance (expr , Array )
888+ if isinstance (expr , Array ) and not is_materialized (expr ):
897889 # Intentionally going to Mapper instead of super() to avoid
898890 # double caching when subclasses of CachedMapper override rec,
899891 # see https://github.com/inducer/pytato/pull/585
@@ -959,7 +951,7 @@ def map_call(self, expr: Call) -> None:
959951
960952 @override
961953 def post_visit (self , expr : ArrayOrNames | FunctionDefinition ) -> None :
962- if not _is_materialized (expr ):
954+ if not is_materialized ( expr ) or not has_taggable_materialization (expr ):
963955 return
964956 assert isinstance (expr , Array )
965957 unmaterialized_expr = expr .without_tags (ImplStored ())
@@ -985,7 +977,10 @@ def _normalize_materialization(expr: ArrayOrNamesTc) -> ArrayOrNamesTc:
985977 # Make sure outputs are materialized
986978 if isinstance (expr , DictOfNamedArrays ):
987979 output_to_materialized_output : dict [Array , Array ] = {
988- ary : ary .tagged (ImplStored ()) if is_materializable (ary ) else ary
980+ ary : (
981+ ary .tagged (ImplStored ())
982+ if has_taggable_materialization (ary )
983+ else ary )
989984 for ary in expr ._data .values ()}
990985
991986 def replace_with_materialized (ary : ArrayOrNames ) -> ArrayOrNames :
0 commit comments