@@ -1016,11 +1016,20 @@ def get_default_op_name_to_num_flops() -> dict[str, int]:
10161016 "max" : 1 }
10171017
10181018
1019+ # FIXME: Should the cost of "If" be the max of the two branches, or the sum?
10191020def get_num_flops (
10201021 expr : ArrayOrNames ,
10211022 op_name_to_num_flops : Mapping [str , int ] | None = None ,
10221023 ) -> ArrayOrScalar :
1023- """Count the total number of floating point operations in the DAG *expr*."""
1024+ """
1025+ Count the total number of floating point operations in the DAG *expr*.
1026+
1027+ .. note::
1028+
1029+ For arrays whose index lambda form contains :class:`pymbolic.primitives.If`,
1030+ this function assumes a SIMT-like model of computation in which the per-entry
1031+ cost is the maximum(??? FIXME) of the costs of the two branches.
1032+ """
10241033 from pytato .codegen import normalize_outputs
10251034 expr = normalize_outputs (expr )
10261035 expr = _normalize_materialization (expr )
@@ -1036,13 +1045,20 @@ def get_num_flops(
10361045 + sum (fc .call_to_nflops .values ()))
10371046
10381047
1048+ # FIXME: Should the cost of "If" be the max of the two branches, or the sum?
10391049def get_materialized_node_flop_counts (
10401050 expr : ArrayOrNames ,
10411051 op_name_to_num_flops : Mapping [str , int ] | None = None ,
10421052 ) -> dict [Array , ArrayOrScalar ]:
10431053 """
10441054 Returns a dictionary mapping materialized nodes in DAG *expr* to their floating
10451055 point operation count.
1056+
1057+ .. note::
1058+
1059+ For arrays whose index lambda form contains :class:`pymbolic.primitives.If`,
1060+ this function assumes a SIMT-like model of computation in which the per-entry
1061+ cost is the maximum(??? FIXME) of the costs of the two branches.
10461062 """
10471063 from pytato .codegen import normalize_outputs
10481064 expr = normalize_outputs (expr )
@@ -1063,6 +1079,7 @@ class UnmaterializedNodeFlopCounts:
10631079 nflops_if_materialized : ArrayOrScalar
10641080
10651081
1082+ # FIXME: Should the cost of "If" be the max of the two branches, or the sum?
10661083def get_unmaterialized_node_flop_counts (
10671084 expr : ArrayOrNames ,
10681085 op_name_to_num_flops : Mapping [str , int ] | None = None ,
@@ -1071,6 +1088,12 @@ def get_unmaterialized_node_flop_counts(
10711088 Returns a dictionary mapping unmaterialized nodes in DAG *expr* to a
10721089 :class:`UnmaterializedNodeFlopCounts` containing floating-point operation count
10731090 information.
1091+
1092+ .. note::
1093+
1094+ For arrays whose index lambda form contains :class:`pymbolic.primitives.If`,
1095+ this function assumes a SIMT-like model of computation in which the per-entry
1096+ cost is the maximum(??? FIXME) of the costs of the two branches.
10741097 """
10751098 from pytato .codegen import normalize_outputs
10761099 expr = normalize_outputs (expr )
0 commit comments