Skip to content

Commit 8164074

Browse files
committed
change is_materializable -> is_materialized / has_taggable_materialization
1 parent a0be1b3 commit 8164074

3 files changed

Lines changed: 56 additions & 34 deletions

File tree

pytato/analysis/__init__.py

Lines changed: 21 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
ShapeType,
5252
Stack,
5353
)
54+
from pytato.diagnostic import CannotBeLoweredToIndexLambda
5455
from pytato.function import Call, FunctionDefinition, NamedCallResult
5556
from pytato.scalar_expr import (
5657
FlopCounter as ScalarFlopCounter,
@@ -66,7 +67,7 @@
6667
map_and_copy,
6768
)
6869
from pytato.transform.lower_to_index_lambda import to_index_lambda
69-
from pytato.utils import is_materializable
70+
from pytato.utils import has_taggable_materialization, is_materialized
7071

7172

7273
if TYPE_CHECKING:
@@ -758,18 +759,6 @@ def update_for_Array(self, key_hash: Any, key: Any) -> None:
758759

759760
# {{{ flop counting
760761

761-
def _is_materialized(expr: ArrayOrNames | FunctionDefinition) -> bool:
762-
return (
763-
is_materializable(expr)
764-
and bool(expr.tags_of_type(ImplStored)))
765-
766-
767-
def _is_unmaterialized(expr: ArrayOrNames | FunctionDefinition) -> bool:
768-
return (
769-
is_materializable(expr)
770-
and not bool(expr.tags_of_type(ImplStored)))
771-
772-
773762
@dataclass
774763
class UndefinedOpFlopCountError(ValueError):
775764
op_name: str
@@ -787,7 +776,10 @@ def combine(self, *args: int) -> int:
787776
return sum(args)
788777

789778
def _get_own_flop_count(self, expr: Array) -> int:
790-
nflops = self.scalar_flop_counter(to_index_lambda(expr).expr)
779+
try:
780+
nflops = self.scalar_flop_counter(to_index_lambda(expr).expr)
781+
except CannotBeLoweredToIndexLambda:
782+
nflops = 0
791783
if not isinstance(nflops, int):
792784
from pytato.scalar_expr import InputGatherer as ScalarInputGatherer
793785
var_names: set[str] = set(ScalarInputGatherer()(nflops))
@@ -805,8 +797,7 @@ def rec(self, expr: ArrayOrNames) -> int:
805797
return self._cache_retrieve(inputs)
806798
except KeyError:
807799
result: int
808-
if _is_unmaterialized(expr):
809-
assert isinstance(expr, Array)
800+
if isinstance(expr, Array) and not is_materialized(expr):
810801
result = (
811802
self._get_own_flop_count(expr)
812803
# Intentionally going to Mapper instead of super() to avoid
@@ -866,14 +857,16 @@ def map_call(self, expr: Call) -> None:
866857

867858
@override
868859
def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None:
869-
if not _is_materialized(expr):
860+
if not is_materialized(expr):
870861
return
871862
assert isinstance(expr, Array)
872-
unmaterialized_expr = expr.without_tags(ImplStored())
873-
self._per_entry_flop_counter(unmaterialized_expr)
874-
self.materialized_node_to_nflops[expr] = (
875-
product(expr.shape)
876-
* self._per_entry_flop_counter.node_to_nflops[unmaterialized_expr])
863+
if has_taggable_materialization(expr):
864+
unmaterialized_expr = expr.without_tags(ImplStored())
865+
self.materialized_node_to_nflops[expr] = (
866+
product(expr.shape)
867+
* self._per_entry_flop_counter(unmaterialized_expr))
868+
else:
869+
self.materialized_node_to_nflops[expr] = 0
877870

878871

879872
class _UnmaterializedSubexpressionUseCounter(CombineMapper[dict[Array, int], Never]):
@@ -892,8 +885,7 @@ def rec(self, expr: ArrayOrNames) -> dict[Array, int]:
892885
return self._cache_retrieve(inputs)
893886
except KeyError:
894887
result: dict[Array, int]
895-
if _is_unmaterialized(expr):
896-
assert isinstance(expr, Array)
888+
if isinstance(expr, Array) and not is_materialized(expr):
897889
# Intentionally going to Mapper instead of super() to avoid
898890
# double caching when subclasses of CachedMapper override rec,
899891
# see https://github.com/inducer/pytato/pull/585
@@ -959,7 +951,7 @@ def map_call(self, expr: Call) -> None:
959951

960952
@override
961953
def post_visit(self, expr: ArrayOrNames | FunctionDefinition) -> None:
962-
if not _is_materialized(expr):
954+
if not is_materialized(expr) or not has_taggable_materialization(expr):
963955
return
964956
assert isinstance(expr, Array)
965957
unmaterialized_expr = expr.without_tags(ImplStored())
@@ -985,7 +977,10 @@ def _normalize_materialization(expr: ArrayOrNamesTc) -> ArrayOrNamesTc:
985977
# Make sure outputs are materialized
986978
if isinstance(expr, DictOfNamedArrays):
987979
output_to_materialized_output: dict[Array, Array] = {
988-
ary: ary.tagged(ImplStored()) if is_materializable(ary) else ary
980+
ary: (
981+
ary.tagged(ImplStored())
982+
if has_taggable_materialization(ary)
983+
else ary)
989984
for ary in expr._data.values()}
990985

991986
def replace_with_materialized(ary: ArrayOrNames) -> ArrayOrNames:

pytato/utils.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,8 @@
8282
.. autofunction:: dim_to_index_lambda_components
8383
.. autofunction:: get_common_dtype_of_ary_or_scalars
8484
.. autofunction:: get_einsum_subscript_str
85-
.. autofunction:: is_materializable
85+
.. autofunction:: is_materialized
86+
.. autofunction:: has_taggable_materialization
8687
8788
References
8889
^^^^^^^^^^
@@ -739,18 +740,40 @@ def get_einsum_specification(expr: Einsum) -> str:
739740
return f"{','.join(input_specs)}->{output_spec}"
740741

741742

742-
def is_materializable(expr: ArrayOrNames | FunctionDefinition) -> bool:
743+
def is_materialized(expr: ArrayOrNames | FunctionDefinition) -> bool:
744+
"""Returns *True* if *expr* is materialized."""
745+
from pytato.array import InputArgumentBase
746+
from pytato.distributed.nodes import DistributedRecv
747+
from pytato.tags import ImplStored
748+
return (
749+
(
750+
isinstance(expr, Array)
751+
and bool(expr.tags_of_type(ImplStored)))
752+
or isinstance(
753+
expr,
754+
(
755+
# FIXME: Is there a nice way to generalize this?
756+
InputArgumentBase,
757+
DistributedRecv)))
758+
759+
760+
def has_taggable_materialization(expr: ArrayOrNames | FunctionDefinition) -> bool:
743761
"""
744-
Returns *True* if *expr* is an instance of an array type that can be materialized.
762+
Returns *True* if *expr* uses the :class:`pytato.tags.ImplStored` tag to
763+
determine whether or not it is materialized.
745764
"""
746765
from pytato.array import InputArgumentBase, NamedArray
747766
from pytato.distributed.nodes import DistributedRecv, DistributedSendRefHolder
748767
return (
749768
isinstance(expr, Array)
750-
and not isinstance(expr, (
751-
# FIXME: Is there a nice way to generalize this?
752-
InputArgumentBase, NamedArray, DistributedRecv,
753-
DistributedSendRefHolder)))
769+
and not isinstance(
770+
expr,
771+
(
772+
# FIXME: Is there a nice way to generalize this?
773+
InputArgumentBase,
774+
DistributedRecv,
775+
NamedArray,
776+
DistributedSendRefHolder)))
754777

755778

756779
# vim: fdm=marker

test/test_pytato.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -980,9 +980,13 @@ def test_materialized_node_flop_counts():
980980

981981
# z[i, j] = x[i, j] + y[i, j]
982982
# expr[i, j] = 2*z[i, j] + (-1)*3*z[i, j]
983-
assert len(materialized_node_to_flop_count) == 2
983+
assert len(materialized_node_to_flop_count) == 4
984+
assert x in materialized_node_to_flop_count
985+
assert y in materialized_node_to_flop_count
984986
assert z in materialized_node_to_flop_count
985987
assert expr.tagged(ImplStored()) in materialized_node_to_flop_count
988+
assert materialized_node_to_flop_count[x] == 0
989+
assert materialized_node_to_flop_count[y] == 0
986990
assert materialized_node_to_flop_count[z] == 40
987991
assert materialized_node_to_flop_count[expr.tagged(ImplStored())] == 40*4
988992

0 commit comments

Comments
 (0)