Skip to content

Commit c885e6c

Browse files
committed
add EqualityMapper to follow pymbolic
1 parent d12c52c commit c885e6c

2 files changed

Lines changed: 85 additions & 41 deletions

File tree

loopy/symbolic.py

Lines changed: 84 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -40,25 +40,28 @@
4040
CallbackMapper as CallbackMapperBase,
4141
CSECachingMapperMixin,
4242
)
43-
from pymbolic.mapper.evaluator import \
44-
EvaluationMapper as EvaluationMapperBase
45-
from pymbolic.mapper.substitutor import \
46-
SubstitutionMapper as SubstitutionMapperBase
47-
from pymbolic.mapper.stringifier import \
48-
StringifyMapper as StringifyMapperBase
49-
from pymbolic.mapper.dependency import \
50-
DependencyMapper as DependencyMapperBase
51-
from pymbolic.mapper.coefficient import \
52-
CoefficientCollector as CoefficientCollectorBase
53-
from pymbolic.mapper.unifier import UnidirectionalUnifier \
54-
as UnidirectionalUnifierBase
55-
from pymbolic.mapper.constant_folder import \
56-
ConstantFoldingMapper as ConstantFoldingMapperBase
43+
from pymbolic.mapper.equality import (
44+
EqualityMapper as EqualityMapperBase)
45+
from pymbolic.mapper.evaluator import (
46+
EvaluationMapper as EvaluationMapperBase)
47+
from pymbolic.mapper.substitutor import (
48+
SubstitutionMapper as SubstitutionMapperBase)
49+
from pymbolic.mapper.stringifier import (
50+
StringifyMapper as StringifyMapperBase)
51+
from pymbolic.mapper.dependency import (
52+
DependencyMapper as DependencyMapperBase)
53+
from pymbolic.mapper.coefficient import (
54+
CoefficientCollector as CoefficientCollectorBase)
55+
from pymbolic.mapper.unifier import (
56+
UnidirectionalUnifier as UnidirectionalUnifierBase)
57+
from pymbolic.mapper.constant_folder import (
58+
ConstantFoldingMapper as ConstantFoldingMapperBase)
5759

5860
from pymbolic.parser import Parser as ParserBase
5961
from loopy.diagnostic import LoopyError
60-
from loopy.diagnostic import (ExpressionToAffineConversionError,
61-
UnableToDetermineAccessRangeError)
62+
from loopy.diagnostic import (
63+
ExpressionToAffineConversionError,
64+
UnableToDetermineAccessRangeError)
6265

6366

6467
import islpy as isl
@@ -114,8 +117,11 @@ def map_literal(self, expr, *args, **kwargs):
114117
return expr
115118

116119
def map_array_literal(self, expr, *args, **kwargs):
117-
return type(expr)(tuple(self.rec(ch, *args, **kwargs)
118-
for ch in expr.children))
120+
children = [self.rec(ch, *args, **kwargs) for ch in expr.children]
121+
if all(ch is orig for ch, orig in zip(children, expr.children)):
122+
return expr
123+
124+
return type(expr)(tuple(children))
119125

120126
def map_group_hw_index(self, expr, *args, **kwargs):
121127
return expr
@@ -474,6 +480,55 @@ def map_substitution(self, name, rule, arguments):
474480

475481
return self.rec(expr)
476482

483+
484+
class EqualityMapper(EqualityMapperBase):
485+
def map_loopy_function_identifier(self, expr, other) -> bool:
486+
return True
487+
488+
def map_reduction(self, expr, other) -> bool:
489+
return (
490+
expr.operation == other.operation
491+
and expr.allow_simultaneous == other.allow_simultaneous
492+
and self.rec(expr.expr, other.expr)
493+
and all(iname == other_iname
494+
for iname, other_iname in zip(expr.inames, other.inames)))
495+
496+
def map_group_hw_index(self, expr, other) -> bool:
497+
return expr.axis == other.axis
498+
499+
map_local_hw_index = map_group_hw_index
500+
501+
def map_rule_argument(self, expr, other) -> bool:
502+
return expr.index == other.index
503+
504+
def map_resolved_function(self, expr, other) -> bool:
505+
return self.rec(expr.function, other.function)
506+
507+
def map_sub_array_ref(self, expr, other) -> bool:
508+
return (
509+
len(expr.swept_inames) == len(other.swept_inames)
510+
and self.rec(expr.subscript, other.subscript)
511+
and all(self.rec(iname, other_iname)
512+
for iname, other_iname in zip(
513+
expr.swept_inames,
514+
other.swept_inames))
515+
)
516+
517+
def map_tagged_variable(self, expr, other) -> bool:
518+
return (
519+
expr.name == other.name
520+
and all(tag == other_tag
521+
for tag, other_tag in zip(expr.tags, other.tags))
522+
)
523+
524+
def map_type_cast(self, expr, other) -> bool:
525+
return (
526+
expr.type == other.type
527+
and self.rec(expr.child, other.child))
528+
529+
def map_fortran_division(self, expr, other) -> bool:
530+
return self.map_quotient(expr, other)
531+
477532
# }}}
478533

479534

@@ -487,15 +542,18 @@ def stringifier(self):
487542
def make_stringifier(self, originating_stringifier=None):
488543
return StringifyMapper()
489544

545+
def make_equality_mapper(self):
546+
return EqualityMapper()
547+
490548

491549
class Literal(LoopyExpressionBase):
492550
"""A literal to be used during code generation.
493551
494552
.. note::
495553
496554
Only used in the output of
497-
:mod:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper` (and
498-
similar mappers). Not for use in Loopy source representation.
555+
:class:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper`
556+
(and similar mappers). Not for use in :mod:`loopy` source representation.
499557
"""
500558

501559
def __init__(self, s):
@@ -515,8 +573,8 @@ class ArrayLiteral(LoopyExpressionBase):
515573
.. note::
516574
517575
Only used in the output of
518-
:mod:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper` (and
519-
similar mappers). Not for use in Loopy source representation.
576+
:class:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper`
577+
(and similar mappers). Not for use in :mod:`loopy` source representation.
520578
"""
521579

522580
def __init__(self, children):
@@ -545,8 +603,8 @@ class GroupHardwareAxisIndex(HardwareAxisIndex):
545603
.. note::
546604
547605
Only used in the output of
548-
:mod:`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
549-
similar mappers). Not for use in Loopy source representation.
606+
:class:`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper`
607+
(and similar mappers). Not for use in :mod:`loopy` source representation.
550608
"""
551609
mapper_method = "map_group_hw_index"
552610

@@ -556,8 +614,8 @@ class LocalHardwareAxisIndex(HardwareAxisIndex):
556614
.. note::
557615
558616
Only used in the output of
559-
:mod:`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
560-
similar mappers). Not for use in Loopy source representation.
617+
:class:`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
618+
similar mappers). Not for use in :mod:`loopy` source representation.
561619
"""
562620
mapper_method = "map_local_hw_index"
563621

@@ -764,12 +822,6 @@ def __getinitargs__(self):
764822
def get_hash(self):
765823
return hash((self.__class__, self.operation, self.inames, self.expr))
766824

767-
def is_equal(self, other):
768-
return (other.__class__ == self.__class__
769-
and other.operation == self.operation
770-
and other.inames == self.inames
771-
and other.expr == self.expr)
772-
773825
@property
774826
def is_tuple_typed(self):
775827
return self.operation.arg_count > 1
@@ -967,14 +1019,6 @@ def __getinitargs__(self):
9671019
def get_hash(self):
9681020
return hash((self.__class__, self.swept_inames, self.subscript))
9691021

970-
def is_equal(self, other):
971-
"""
972-
Returns *True* iff the sub-array refs have identical expressions.
973-
"""
974-
return (other.__class__ == self.__class__
975-
and other.subscript == self.subscript
976-
and other.swept_inames == self.swept_inames)
977-
9781022
def make_stringifier(self, originating_stringifier=None):
9791023
return StringifyMapper()
9801024

requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ git+https://github.com/inducer/pytools.git#egg=pytools >= 2021.1
22
git+https://github.com/inducer/islpy.git#egg=islpy
33
git+https://github.com/inducer/cgen.git#egg=cgen
44
git+https://github.com/inducer/pyopencl.git#egg=pyopencl
5-
git+https://github.com/inducer/pymbolic.git#egg=pymbolic
5+
git+https://github.com/alexfikl/pymbolic.git@equality-mapper#egg=pymbolic
66
git+https://github.com/inducer/genpy.git#egg=genpy
77
git+https://github.com/inducer/codepy.git#egg=codepy
88

0 commit comments

Comments
 (0)