Skip to content

Commit f077dd1

Browse files
committed
add note to docs about assumptions when handling conditional expressions
1 parent ef5470f commit f077dd1

1 file changed

Lines changed: 21 additions & 1 deletion

File tree

pytato/analysis/__init__.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1033,7 +1033,15 @@ def get_num_flops(
10331033
expr: ArrayOrNames,
10341034
op_name_to_num_flops: Mapping[str, int] | None = None,
10351035
) -> ArrayOrScalar:
1036-
"""Count the total number of floating point operations in the DAG *expr*."""
1036+
"""
1037+
Count the total number of floating point operations in the DAG *expr*.
1038+
1039+
.. note::
1040+
1041+
For arrays whose index lambda form contains :class:`pymbolic.primitives.If`,
1042+
this function assumes a SIMT-like model of computation in which the per-entry
1043+
cost is the sum of the costs of the two branches.
1044+
"""
10371045
from pytato.codegen import normalize_outputs
10381046
expr = normalize_outputs(expr)
10391047
expr = _normalize_materialization(expr)
@@ -1056,6 +1064,12 @@ def get_materialized_node_flop_counts(
10561064
"""
10571065
Returns a dictionary mapping materialized nodes in DAG *expr* to their floating
10581066
point operation count.
1067+
1068+
.. note::
1069+
1070+
For arrays whose index lambda form contains :class:`pymbolic.primitives.If`,
1071+
this function assumes a SIMT-like model of computation in which the per-entry
1072+
cost is the sum of the costs of the two branches.
10591073
"""
10601074
from pytato.codegen import normalize_outputs
10611075
expr = normalize_outputs(expr)
@@ -1078,6 +1092,12 @@ def get_unmaterialized_node_flop_counts(
10781092
Returns a dictionary mapping unmaterialized nodes in DAG *expr* to a
10791093
:class:`UnmaterializedNodeFlopCounts` containing floating-point operation count
10801094
information.
1095+
1096+
.. note::
1097+
1098+
For arrays whose index lambda form contains :class:`pymbolic.primitives.If`,
1099+
this function assumes a SIMT-like model of computation in which the per-entry
1100+
cost is the sum of the costs of the two branches.
10811101
"""
10821102
from pytato.codegen import normalize_outputs
10831103
expr = normalize_outputs(expr)

0 commit comments

Comments
 (0)