Skip to content

Commit dbb50db

Browse files
committed
add placeholder class for operator flop counts that aren't specified
1 parent e887992 commit dbb50db

2 files changed

Lines changed: 44 additions & 8 deletions

File tree

pytato/analysis/__init__.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -798,13 +798,16 @@ def _get_own_flop_count(self, expr: Array) -> int:
798798
return 0
799799
nflops = self.scalar_flop_counter(to_index_lambda(expr).expr)
800800
if not isinstance(nflops, int):
801-
from pytato.scalar_expr import InputGatherer as ScalarInputGatherer
802-
var_names: set[str] = set(ScalarInputGatherer()(nflops))
803-
var_names.discard("nflops")
804-
if var_names:
805-
raise UndefinedOpFlopCountError(next(iter(var_names))) from None
801+
# Restricting to numerical result here because the flop counters that use
802+
# this mapper subsequently multiply the result by things that are
803+
# potentially arrays (e.g., shape components), and arrays and scalar
804+
# expressions are not interoperable
805+
from pytato.scalar_expr import OpFlops, OpFlopsCollector
806+
op_flops: frozenset[OpFlops] = OpFlopsCollector()(nflops)
807+
if op_flops:
808+
raise UndefinedOpFlopCountError(next(iter(op_flops)).op)
806809
else:
807-
raise AssertionError from None
810+
raise AssertionError
808811
return nflops
809812

810813
@override

pytato/scalar_expr.py

Lines changed: 35 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444

4545
import re
4646
from collections.abc import Iterable, Mapping, Set
47+
from functools import reduce
4748
from typing import (
4849
TYPE_CHECKING,
4950
Any,
@@ -266,8 +267,7 @@ def _get_op_nflops(self, name: str) -> ArithmeticExpression:
266267
try:
267268
return self.op_name_to_num_flops[name]
268269
except KeyError:
269-
from pymbolic import var
270-
result = var("nflops")(var(name))
270+
result = OpFlops(name)
271271
self.op_name_to_num_flops[name] = result
272272
return result
273273

@@ -479,9 +479,42 @@ class TypeCast(ExpressionBase):
479479
dtype: np.dtype[Any]
480480
inner_expr: ScalarExpression
481481

482+
483+
@expr_dataclass()
484+
class OpFlops(prim.AlgebraicLeaf):
485+
"""
486+
Placeholder flop count for an operator.
487+
488+
.. autoattribute:: op
489+
"""
490+
op: str
491+
482492
# }}}
483493

484494

495+
class OpFlopsCollector(CombineMapper[frozenset[OpFlops], []]):
496+
"""
497+
Constructs a :class:`frozenset` containing all instances of
498+
:class:`pytato.scalar_expr.OpFlops` found in a scalar expression.
499+
"""
500+
@override
501+
def combine(
502+
self, values: Iterable[frozenset[OpFlops]]) -> frozenset[OpFlops]:
503+
return reduce(
504+
lambda x, y: x.union(y),
505+
values,
506+
cast("frozenset[OpFlops]", frozenset()))
507+
508+
@override
509+
def map_algebraic_leaf(
510+
self, expr: prim.AlgebraicLeaf) -> frozenset[OpFlops]:
511+
return frozenset([expr]) if isinstance(expr, OpFlops) else frozenset()
512+
513+
@override
514+
def map_constant(self, expr: object) -> frozenset[OpFlops]:
515+
return frozenset()
516+
517+
485518
class InductionVariableCollector(CombineMapper[Set[str], []]):
486519
def combine(self, values: Iterable[Set[str]]) -> frozenset[str]:
487520
from functools import reduce

0 commit comments

Comments
 (0)