4040 CallbackMapper as CallbackMapperBase ,
4141 CSECachingMapperMixin ,
4242 )
43- from pymbolic .mapper .evaluator import \
44- EvaluationMapper as EvaluationMapperBase
45- from pymbolic .mapper .substitutor import \
46- SubstitutionMapper as SubstitutionMapperBase
47- from pymbolic .mapper .stringifier import \
48- StringifyMapper as StringifyMapperBase
49- from pymbolic .mapper .dependency import \
50- DependencyMapper as DependencyMapperBase
51- from pymbolic .mapper .coefficient import \
52- CoefficientCollector as CoefficientCollectorBase
53- from pymbolic .mapper .unifier import UnidirectionalUnifier \
54- as UnidirectionalUnifierBase
55- from pymbolic .mapper .constant_folder import \
56- ConstantFoldingMapper as ConstantFoldingMapperBase
43+ from pymbolic .mapper .equality import (
44+ EqualityMapper as EqualityMapperBase )
45+ from pymbolic .mapper .evaluator import (
46+ EvaluationMapper as EvaluationMapperBase )
47+ from pymbolic .mapper .substitutor import (
48+ SubstitutionMapper as SubstitutionMapperBase )
49+ from pymbolic .mapper .stringifier import (
50+ StringifyMapper as StringifyMapperBase )
51+ from pymbolic .mapper .dependency import (
52+ DependencyMapper as DependencyMapperBase )
53+ from pymbolic .mapper .coefficient import (
54+ CoefficientCollector as CoefficientCollectorBase )
55+ from pymbolic .mapper .unifier import (
56+ UnidirectionalUnifier as UnidirectionalUnifierBase )
57+ from pymbolic .mapper .constant_folder import (
58+ ConstantFoldingMapper as ConstantFoldingMapperBase )
5759
5860from pymbolic .parser import Parser as ParserBase
5961from loopy .diagnostic import LoopyError
60- from loopy .diagnostic import (ExpressionToAffineConversionError ,
61- UnableToDetermineAccessRangeError )
62+ from loopy .diagnostic import (
63+ ExpressionToAffineConversionError ,
64+ UnableToDetermineAccessRangeError )
6265
6366
6467import islpy as isl
@@ -114,8 +117,11 @@ def map_literal(self, expr, *args, **kwargs):
114117 return expr
115118
116119 def map_array_literal (self , expr , * args , ** kwargs ):
117- return type (expr )(tuple (self .rec (ch , * args , ** kwargs )
118- for ch in expr .children ))
120+ children = [self .rec (ch , * args , ** kwargs ) for ch in expr .children ]
121+ if all (ch is orig for ch , orig in zip (children , expr .children )):
122+ return expr
123+
124+ return type (expr )(tuple (children ))
119125
120126 def map_group_hw_index (self , expr , * args , ** kwargs ):
121127 return expr
@@ -474,6 +480,55 @@ def map_substitution(self, name, rule, arguments):
474480
475481 return self .rec (expr )
476482
483+
484+ class EqualityMapper (EqualityMapperBase ):
485+ def map_loopy_function_identifier (self , expr , other ) -> bool :
486+ return True
487+
488+ def map_reduction (self , expr , other ) -> bool :
489+ return (
490+ expr .operation == other .operation
491+ and expr .allow_simultaneous == other .allow_simultaneous
492+ and self .rec (expr .expr , other .expr )
493+ and all (iname == other_iname
494+ for iname , other_iname in zip (expr .inames , other .inames )))
495+
496+ def map_group_hw_index (self , expr , other ) -> bool :
497+ return expr .axis == other .axis
498+
499+ map_local_hw_index = map_group_hw_index
500+
501+ def map_rule_argument (self , expr , other ) -> bool :
502+ return expr .index == other .index
503+
504+ def map_resolved_function (self , expr , other ) -> bool :
505+ return self .rec (expr .function , other .function )
506+
507+ def map_sub_array_ref (self , expr , other ) -> bool :
508+ return (
509+ len (expr .swept_inames ) == len (other .swept_inames )
510+ and self .rec (expr .subscript , other .subscript )
511+ and all (self .rec (iname , other_iname )
512+ for iname , other_iname in zip (
513+ expr .swept_inames ,
514+ other .swept_inames ))
515+ )
516+
517+ def map_tagged_variable (self , expr , other ) -> bool :
518+ return (
519+ expr .name == other .name
520+ and all (tag == other_tag
521+ for tag , other_tag in zip (expr .tags , other .tags ))
522+ )
523+
524+ def map_type_cast (self , expr , other ) -> bool :
525+ return (
526+ expr .type == other .type
527+ and self .rec (expr .child , other .child ))
528+
529+ def map_fortran_division (self , expr , other ) -> bool :
530+ return self .map_quotient (expr , other )
531+
477532# }}}
478533
479534
@@ -487,15 +542,18 @@ def stringifier(self):
487542 def make_stringifier (self , originating_stringifier = None ):
488543 return StringifyMapper ()
489544
545+ def make_equality_mapper (self ):
546+ return EqualityMapper ()
547+
490548
491549class Literal (LoopyExpressionBase ):
492550 """A literal to be used during code generation.
493551
494552 .. note::
495553
496554 Only used in the output of
497- :mod :`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper` (and
498- similar mappers). Not for use in Loopy source representation.
555+ :class :`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper`
556+ (and similar mappers). Not for use in :mod:`loopy` source representation.
499557 """
500558
501559 def __init__ (self , s ):
@@ -515,8 +573,8 @@ class ArrayLiteral(LoopyExpressionBase):
515573 .. note::
516574
517575 Only used in the output of
518- :mod :`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper` (and
519- similar mappers). Not for use in Loopy source representation.
576+ :class :`loopy.target.c.codegen.expression.ExpressionToCExpressionMapper`
577+ (and similar mappers). Not for use in :mod:`loopy` source representation.
520578 """
521579
522580 def __init__ (self , children ):
@@ -545,8 +603,8 @@ class GroupHardwareAxisIndex(HardwareAxisIndex):
545603 .. note::
546604
547605 Only used in the output of
548- :mod :`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
549- similar mappers). Not for use in Loopy source representation.
606+ :class :`loopy.target.c.codegen. expression.ExpressionToCExpressionMapper`
607+ (and similar mappers). Not for use in :mod:`loopy` source representation.
550608 """
551609 mapper_method = "map_group_hw_index"
552610
@@ -556,8 +614,8 @@ class LocalHardwareAxisIndex(HardwareAxisIndex):
556614 .. note::
557615
558616 Only used in the output of
559- :mod :`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
560- similar mappers). Not for use in Loopy source representation.
617+ :class :`loopy.target.c.expression.ExpressionToCExpressionMapper` (and
618+ similar mappers). Not for use in :mod:`loopy` source representation.
561619 """
562620 mapper_method = "map_local_hw_index"
563621
@@ -764,12 +822,6 @@ def __getinitargs__(self):
764822 def get_hash (self ):
765823 return hash ((self .__class__ , self .operation , self .inames , self .expr ))
766824
767- def is_equal (self , other ):
768- return (other .__class__ == self .__class__
769- and other .operation == self .operation
770- and other .inames == self .inames
771- and other .expr == self .expr )
772-
773825 @property
774826 def is_tuple_typed (self ):
775827 return self .operation .arg_count > 1
@@ -967,14 +1019,6 @@ def __getinitargs__(self):
9671019 def get_hash (self ):
9681020 return hash ((self .__class__ , self .swept_inames , self .subscript ))
9691021
970- def is_equal (self , other ):
971- """
972- Returns *True* iff the sub-array refs have identical expressions.
973- """
974- return (other .__class__ == self .__class__
975- and other .subscript == self .subscript
976- and other .swept_inames == self .swept_inames )
977-
9781022 def make_stringifier (self , originating_stringifier = None ):
9791023 return StringifyMapper ()
9801024
0 commit comments