Skip to content

Commit 5236448

Browse files
committed
add note to docs about assumptions when handling conditional expressions
1 parent 0d23c69 commit 5236448

1 file changed

Lines changed: 24 additions & 1 deletion

File tree

pytato/analysis/__init__.py

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1016,11 +1016,20 @@ def get_default_op_name_to_num_flops() -> dict[str, int]:
10161016
"max": 1}
10171017

10181018

1019+
# FIXME: Should the cost of "If" be the max of the two branches, or the sum?
10191020
def get_num_flops(
10201021
expr: ArrayOrNames,
10211022
op_name_to_num_flops: Mapping[str, int] | None = None,
10221023
) -> ArrayOrScalar:
1023-
"""Count the total number of floating point operations in the DAG *expr*."""
1024+
"""
1025+
Count the total number of floating point operations in the DAG *expr*.
1026+
1027+
.. note::
1028+
1029+
For arrays whose index lambda form contains :class:`pymbolic.primitives.If`,
1030+
this function assumes a SIMT-like model of computation in which the per-entry
1031+
cost is the maximum(??? FIXME) of the costs of the two branches.
1032+
"""
10241033
from pytato.codegen import normalize_outputs
10251034
expr = normalize_outputs(expr)
10261035
expr = _normalize_materialization(expr)
@@ -1036,13 +1045,20 @@ def get_num_flops(
10361045
+ sum(fc.call_to_nflops.values()))
10371046

10381047

1048+
# FIXME: Should the cost of "If" be the max of the two branches, or the sum?
10391049
def get_materialized_node_flop_counts(
10401050
expr: ArrayOrNames,
10411051
op_name_to_num_flops: Mapping[str, int] | None = None,
10421052
) -> dict[Array, ArrayOrScalar]:
10431053
"""
10441054
Returns a dictionary mapping materialized nodes in DAG *expr* to their floating
10451055
point operation count.
1056+
1057+
.. note::
1058+
1059+
For arrays whose index lambda form contains :class:`pymbolic.primitives.If`,
1060+
this function assumes a SIMT-like model of computation in which the per-entry
1061+
cost is the maximum(??? FIXME) of the costs of the two branches.
10461062
"""
10471063
from pytato.codegen import normalize_outputs
10481064
expr = normalize_outputs(expr)
@@ -1063,6 +1079,7 @@ class UnmaterializedNodeFlopCounts:
10631079
nflops_if_materialized: ArrayOrScalar
10641080

10651081

1082+
# FIXME: Should the cost of "If" be the max of the two branches, or the sum?
10661083
def get_unmaterialized_node_flop_counts(
10671084
expr: ArrayOrNames,
10681085
op_name_to_num_flops: Mapping[str, int] | None = None,
@@ -1071,6 +1088,12 @@ def get_unmaterialized_node_flop_counts(
10711088
Returns a dictionary mapping unmaterialized nodes in DAG *expr* to a
10721089
:class:`UnmaterializedNodeFlopCounts` containing floating-point operation count
10731090
information.
1091+
1092+
.. note::
1093+
1094+
For arrays whose index lambda form contains :class:`pymbolic.primitives.If`,
1095+
this function assumes a SIMT-like model of computation in which the per-entry
1096+
cost is the maximum(??? FIXME) of the costs of the two branches.
10741097
"""
10751098
from pytato.codegen import normalize_outputs
10761099
expr = normalize_outputs(expr)

0 commit comments

Comments
 (0)