@@ -1033,7 +1033,15 @@ def get_num_flops(
10331033 expr : ArrayOrNames ,
10341034 op_name_to_num_flops : Mapping [str , int ] | None = None ,
10351035 ) -> ArrayOrScalar :
1036- """Count the total number of floating point operations in the DAG *expr*."""
1036+ """
1037+ Count the total number of floating point operations in the DAG *expr*.
1038+
1039+ .. note::
1040+
1041+ For arrays whose index lambda form contains :class:`pymbolic.primitives.If`,
1042+ this function assumes a SIMT-like model of computation in which the per-entry
1043+ cost is the sum of the costs of the two branches.
1044+ """
10371045 from pytato .codegen import normalize_outputs
10381046 expr = normalize_outputs (expr )
10391047 expr = _normalize_materialization (expr )
@@ -1056,6 +1064,12 @@ def get_materialized_node_flop_counts(
10561064 """
10571065 Returns a dictionary mapping materialized nodes in DAG *expr* to their floating
10581066 point operation count.
1067+
1068+ .. note::
1069+
1070+ For arrays whose index lambda form contains :class:`pymbolic.primitives.If`,
1071+ this function assumes a SIMT-like model of computation in which the per-entry
1072+ cost is the sum of the costs of the two branches.
10591073 """
10601074 from pytato .codegen import normalize_outputs
10611075 expr = normalize_outputs (expr )
@@ -1078,6 +1092,12 @@ def get_unmaterialized_node_flop_counts(
10781092 Returns a dictionary mapping unmaterialized nodes in DAG *expr* to a
10791093 :class:`UnmaterializedNodeFlopCounts` containing floating-point operation count
10801094 information.
1095+
1096+ .. note::
1097+
1098+ For arrays whose index lambda form contains :class:`pymbolic.primitives.If`,
1099+ this function assumes a SIMT-like model of computation in which the per-entry
1100+ cost is the sum of the costs of the two branches.
10811101 """
10821102 from pytato .codegen import normalize_outputs
10831103 expr = normalize_outputs (expr )
0 commit comments