From 37f2c1435166d5984838acb77c7d3acc066b72a0 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Thu, 17 Jul 2025 10:38:03 +0100 Subject: [PATCH 01/22] compiler: Add FunctionMap type --- devito/types/parallel.py | 18 ++++++++++++++ tests/test_iet.py | 53 ++++++++++++++++++++++++++++++++++++++-- 2 files changed, 69 insertions(+), 2 deletions(-) diff --git a/devito/types/parallel.py b/devito/types/parallel.py index ff9e1405a9..40da65db37 100644 --- a/devito/types/parallel.py +++ b/devito/types/parallel.py @@ -20,11 +20,13 @@ from devito.types.basic import Scalar, Symbol from devito.types.dimension import CustomDimension from devito.types.misc import Fence, VolatileInt +from devito.types.object import LocalObject __all__ = [ 'Barrier', 'DeviceID', 'DeviceRM', + 'FunctionMap', 'Lock', 'NPThreads', 'NThreads', @@ -384,3 +386,19 @@ def __init_finalize__(self, *args, **kwargs): kwargs['liveness'] = 'eager' super().__init_finalize__(*args, **kwargs) + + +class FunctionMap(LocalObject): + + """ + Wrap a Function in a LocalObject. + """ + + __rargs__ = ('name', 'tensor') + + def __init__(self, name, tensor, **kwargs): + super().__init__(name, **kwargs) + self.tensor = tensor + + def _hashable_content(self): + return super()._hashable_content() + (self.tensor,) diff --git a/tests/test_iet.py b/tests/test_iet.py index 7bc4f1e709..1361ed569c 100644 --- a/tests/test_iet.py +++ b/tests/test_iet.py @@ -17,10 +17,13 @@ from devito.passes.iet.engine import Graph from devito.passes.iet.languages.C import CDataManager from devito.symbolics import ( - FLOAT, Byref, Class, FieldFromComposite, InlineIf, Macro, String + FLOAT, Byref, Class, FieldFromComposite, InlineIf, ListInitializer, + Macro, SizeOf, String ) from devito.tools import CustomDtype, as_tuple, dtype_to_ctype -from devito.types import Array, CustomDimension, LocalObject, Pointer, Symbol +from devito.types import ( + Array, CustomDimension, FunctionMap, LocalObject, Pointer, Symbol +) @pytest.fixture @@ -299,6 +302,52 @@ def _C_free(self): }""" +def test_make_cuda_tensor_map(): + + class CUTensorMap(FunctionMap): + + dtype = CustomDtype('CUtensorMap') + + @property + def _C_init(self): + symsizes = list(reversed(self.tensor.symbolic_shape)) + sizeof_dtype = SizeOf(self.tensor.dmap._C_typedata) + + sizes = ListInitializer(symsizes) + strides = ListInitializer([ + np.prod(symsizes[:i])*sizeof_dtype for i in range(1, len(symsizes)) + ]) + + arguments = [ + Byref(self), + Macro('CU_TENSOR_MAP_DATA_TYPE_FLOAT32'), + 4, self.tensor.dmap, sizes, strides, + ] + call = Call('cuTensorMapEncodeTiled', arguments) + + return call + + grid = Grid(shape=(10, 10, 10)) + + u = TimeFunction(name='u', grid=grid) + + tmap = CUTensorMap('tmap', u) + + iet = Call('foo', tmap) + iet = ElementalFunction('foo', iet, parameters=()) + dm = CDataManager(sregistry=None) + iet = CDataManager.place_definitions.__wrapped__(dm, iet)[0] + + assert str(iet) == """\ +static void foo() +{ + CUtensorMap tmap; + cuTensorMapEncodeTiled(&(tmap),CU_TENSOR_MAP_DATA_TYPE_FLOAT32,4,d_u,{u_vec->size[3], u_vec->size[2], u_vec->size[1], u_vec->size[0]},{sizeof(float)*u_vec->size[3], sizeof(float)*u_vec->size[2]*u_vec->size[3], sizeof(float)*u_vec->size[1]*u_vec->size[2]*u_vec->size[3]}); + + foo(tmap); +}""" + + def test_cpp_local_object(): """ Test C++ support for LocalObjects. From fea9884024efe9c0660b2da4cf47831e4d0ff386 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Thu, 17 Jul 2025 11:57:02 +0100 Subject: [PATCH 02/22] compiler: Add ULONG to __all__ --- devito/symbolics/extended_dtypes.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/devito/symbolics/extended_dtypes.py b/devito/symbolics/extended_dtypes.py index 29bab821ca..5a2bcc47f7 100644 --- a/devito/symbolics/extended_dtypes.py +++ b/devito/symbolics/extended_dtypes.py @@ -10,8 +10,8 @@ from devito.tools.dtypes_lowering import dtype_mapper __all__ = ['cast', 'CustomType', 'limits_mapper', 'INT', 'FLOAT', 'BaseCast', # noqa - 'DOUBLE', 'VOID', 'NoDeclStruct', 'c_complex', 'c_double_complex', - 'LONG'] + 'DOUBLE', 'VOID', 'LONG', 'ULONG', 'NoDeclStruct', 'c_complex', + 'c_double_complex'] limits_mapper = { From cb6bad394ab418ae5b2b4feebbf71970f01e65f6 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Thu, 17 Jul 2025 12:34:40 +0100 Subject: [PATCH 03/22] compiler: Improve lowering of LocalObjects --- devito/passes/iet/definitions.py | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/devito/passes/iet/definitions.py b/devito/passes/iet/definitions.py index 29cc8c9787..1d5b0a69d4 100644 --- a/devito/passes/iet/definitions.py +++ b/devito/passes/iet/definitions.py @@ -98,14 +98,25 @@ def _alloc_object_on_low_lat_mem(self, site, obj, storage): """ decl = Definition(obj) - definition = (decl, obj._C_init) if obj._C_init else (decl) + init = obj._C_init + if not init: + definition = decl + efuncs = () + elif init.is_Callable: + definition = Call(init.name, init.parameters, + retobj=obj if init.retval else None) + efuncs = (init,) + else: + definition = (decl, init) + efuncs = () frees = obj._C_free if obj.free_symbols - {obj}: - storage.update(obj, site, objs=definition, frees=frees) + storage.update(obj, site, objs=definition, efuncs=efuncs, frees=frees) else: - storage.update(obj, site, standalones=definition, frees=frees) + storage.update(obj, site, standalones=definition, efuncs=efuncs, + frees=frees) def _alloc_array_on_low_lat_mem(self, site, obj, storage): """ From 02a9f9fa3a62b1e66f45b73b465baf045d3569b5 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Tue, 22 Jul 2025 14:18:17 +0100 Subject: [PATCH 04/22] compiler: Add LocalObject._mem_shared --- devito/types/object.py | 19 ++++++++++++------- tests/test_iet.py | 2 +- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/devito/types/object.py b/devito/types/object.py index 4c30e6f8ff..dd6ed0d7a4 100644 --- a/devito/types/object.py +++ b/devito/types/object.py @@ -177,10 +177,10 @@ class LocalObject(AbstractObject, LocalType): """ __rargs__ = ('name',) - __rkwargs__ = ('cargs', 'initvalue', 'liveness', 'is_global') + __rkwargs__ = ('cargs', 'initvalue', 'liveness', 'scope') def __init__(self, name, cargs=None, initvalue=None, liveness='lazy', - is_global=False, **kwargs): + scope='stack', **kwargs): self.name = name self.cargs = as_tuple(cargs) @@ -192,16 +192,17 @@ def __init__(self, name, cargs=None, initvalue=None, liveness='lazy', assert liveness in ['eager', 'lazy'] self._liveness = liveness - self._is_global = is_global + assert scope in ['stack', 'shared', 'global'] + self._scope = scope def _hashable_content(self): return (super()._hashable_content() + self.cargs + - (self.initvalue, self.liveness, self.is_global)) + (self.initvalue, self.liveness, self.scope)) @property - def is_global(self): - return self._is_global + def scope(self): + return self._scope @property def free_symbols(self): @@ -235,6 +236,10 @@ def _C_free(self): """ return None + @property + def _mem_shared(self): + return self._scope == 'shared' + @property def _mem_global(self): - return self._is_global + return self._scope == 'global' diff --git a/tests/test_iet.py b/tests/test_iet.py index 1361ed569c..425b714e2e 100644 --- a/tests/test_iet.py +++ b/tests/test_iet.py @@ -360,7 +360,7 @@ class MyObject(LocalObject): lo0 = MyObject('obj0') # Globally-scoped objects must not be declared in the function body - lo1 = MyObject('obj1', is_global=True) + lo1 = MyObject('obj1', scope='global') # A LocalObject using both a template and a modifier class SpecialObject(LocalObject): From a530eefb3450180f162594cbec603ff5ae02cb75 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Mon, 28 Jul 2025 15:56:37 +0100 Subject: [PATCH 05/22] compiler: Add LocalType._C_tag --- devito/ir/iet/visitors.py | 3 +++ devito/types/basic.py | 13 +++++++++++-- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index b450960354..fb73d25004 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -321,6 +321,9 @@ def _gen_value(self, obj, mode=1, masked=()): qualifiers = [v for k, v in self._qualifiers_mapper.items() if getattr(obj.function, k, False) and v not in masked] + if obj.is_LocalObject and mode == 2: + qualifiers.extend(as_tuple(obj._C_tag)) + if (obj._mem_stack or obj._mem_constant) and mode == 1: strtype = self.ccode(obj._C_typedata) strshape = ''.join(f'[{self.ccode(i)}]' for i in obj.symbolic_shape) diff --git a/devito/types/basic.py b/devito/types/basic.py index a3b8af4533..00747c8114 100644 --- a/devito/types/basic.py +++ b/devito/types/basic.py @@ -1937,8 +1937,17 @@ def _mem_internal_lazy(self): return self._liveness == 'lazy' """ - A modifier added to the subclass C declaration when it appears - in a function signature. For example, a subclass might define `_C_modifier = '&'` + A modifier added to the declaration of the LocalType when it appears in a + function signature. For example, a subclass might define `_C_modifier = '&'` to impose pass-by-reference semantics. """ _C_modifier = None + + """ + One or more optional keywords added to the declaration of the LocalType + in between the type and the variable name when it appears in a function + signature. For example, some languages support these to modify the way + the compiler generates code for passing the parameter and how the + runtime accesses it. + """ + _C_tag = None From 873f8a0fcadda8e1fe09f5885d4cd4172b4617ff Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Tue, 29 Jul 2025 11:48:41 +0100 Subject: [PATCH 06/22] compiler: Move and enhance FunctionMap --- devito/types/misc.py | 24 ++++++++++++++++++++++++ devito/types/parallel.py | 18 ------------------ tests/test_iet.py | 9 +++++---- 3 files changed, 29 insertions(+), 22 deletions(-) diff --git a/devito/types/misc.py b/devito/types/misc.py index 571197b717..9c06cc5a70 100644 --- a/devito/types/misc.py +++ b/devito/types/misc.py @@ -381,6 +381,30 @@ def closing(self): """ +class FunctionMap(LocalObject): + + """ + Wrap a Function in a LocalObject. + """ + + __rargs__ = ('name', 'tensor') + + def __init__(self, name, tensor, **kwargs): + super().__init__(name, **kwargs) + self.tensor = tensor + + def _hashable_content(self): + return super()._hashable_content() + (self.tensor,) + + @property + def free_symbols(self): + """ + The free symbols of a FunctionMap are the free symbols of the + underlying Function. + """ + return super().free_symbols | {self.tensor} + + # *** C/CXX support types size_t = CustomDtype('size_t') diff --git a/devito/types/parallel.py b/devito/types/parallel.py index 40da65db37..ff9e1405a9 100644 --- a/devito/types/parallel.py +++ b/devito/types/parallel.py @@ -20,13 +20,11 @@ from devito.types.basic import Scalar, Symbol from devito.types.dimension import CustomDimension from devito.types.misc import Fence, VolatileInt -from devito.types.object import LocalObject __all__ = [ 'Barrier', 'DeviceID', 'DeviceRM', - 'FunctionMap', 'Lock', 'NPThreads', 'NThreads', @@ -386,19 +384,3 @@ def __init_finalize__(self, *args, **kwargs): kwargs['liveness'] = 'eager' super().__init_finalize__(*args, **kwargs) - - -class FunctionMap(LocalObject): - - """ - Wrap a Function in a LocalObject. - """ - - __rargs__ = ('name', 'tensor') - - def __init__(self, name, tensor, **kwargs): - super().__init__(name, **kwargs) - self.tensor = tensor - - def _hashable_content(self): - return super()._hashable_content() + (self.tensor,) diff --git a/tests/test_iet.py b/tests/test_iet.py index 425b714e2e..ac58f9146b 100644 --- a/tests/test_iet.py +++ b/tests/test_iet.py @@ -22,8 +22,9 @@ ) from devito.tools import CustomDtype, as_tuple, dtype_to_ctype from devito.types import ( - Array, CustomDimension, FunctionMap, LocalObject, Pointer, Symbol + Array, CustomDimension, LocalObject, Pointer, Symbol ) +from devito.types.misc import FunctionMap @pytest.fixture @@ -322,7 +323,7 @@ def _C_init(self): Byref(self), Macro('CU_TENSOR_MAP_DATA_TYPE_FLOAT32'), 4, self.tensor.dmap, sizes, strides, - ] + ] call = Call('cuTensorMapEncodeTiled', arguments) return call @@ -342,10 +343,10 @@ def _C_init(self): static void foo() { CUtensorMap tmap; - cuTensorMapEncodeTiled(&(tmap),CU_TENSOR_MAP_DATA_TYPE_FLOAT32,4,d_u,{u_vec->size[3], u_vec->size[2], u_vec->size[1], u_vec->size[0]},{sizeof(float)*u_vec->size[3], sizeof(float)*u_vec->size[2]*u_vec->size[3], sizeof(float)*u_vec->size[1]*u_vec->size[2]*u_vec->size[3]}); + cuTensorMapEncodeTiled(&tmap,CU_TENSOR_MAP_DATA_TYPE_FLOAT32,4,d_u,{u_vec->size[3], u_vec->size[2], u_vec->size[1], u_vec->size[0]},{sizeof(float)*u_vec->size[3], sizeof(float)*u_vec->size[2]*u_vec->size[3], sizeof(float)*u_vec->size[1]*u_vec->size[2]*u_vec->size[3]}); foo(tmap); -}""" +}""" # noqa def test_cpp_local_object(): From ddb877e19183ddba5e5616f27865b6a0b45f1ef0 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Tue, 29 Jul 2025 15:46:59 +0100 Subject: [PATCH 07/22] arch: async-loads -> async-pipe --- devito/arch/archinfo.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/devito/arch/archinfo.py b/devito/arch/archinfo.py index 88d1b3304c..1a17ba6da1 100644 --- a/devito/arch/archinfo.py +++ b/devito/arch/archinfo.py @@ -1139,7 +1139,7 @@ def supports(self, query, language=None): warning(f"Couldn't establish if `query={query}` is supported on this " "system. Assuming it is not.") return False - elif query == 'async-loads' and cc >= 80: + elif query == 'async-pipe' and cc >= 80: # Asynchronous pipeline loads -- introduced in Ampere return True elif query in ('tma', 'thread-block-cluster') and cc >= 90: # noqa: SIM103 @@ -1156,7 +1156,7 @@ class Volta(NvidiaDevice): class Ampere(Volta): def supports(self, query, language=None): - if query == 'async-loads': + if query == 'async-pipe': return True else: return super().supports(query, language) From 01255cfb9948aa3c5a15ea5e0e264d14359642ec Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Fri, 8 Aug 2025 14:16:32 +0100 Subject: [PATCH 08/22] compiler: Fix IREq.__repr__ --- devito/ir/equations/equation.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/devito/ir/equations/equation.py b/devito/ir/equations/equation.py index 29945903a9..1242a84fc5 100644 --- a/devito/ir/equations/equation.py +++ b/devito/ir/equations/equation.py @@ -92,9 +92,11 @@ def __repr__(self): if not self.is_Reduction: return super().__repr__() elif self.operation is OpInc: - return f'{self.lhs} += {self.rhs}' + return f'Inc({self.lhs}, {self.rhs})' else: - return f'{self.lhs} = {self.operation}({self.rhs})' + return f'Eq({self.lhs}, {self.operation}({self.rhs}))' + + __str__ = __repr__ # Pickling support __reduce_ex__ = Pickable.__reduce_ex__ From cc4a53f0aafad3947d8edfe36e0244e985c7485b Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Fri, 8 Aug 2025 14:35:13 +0100 Subject: [PATCH 09/22] compiler: Generalize ideriv lowering --- devito/passes/clusters/derivatives.py | 44 ++++++++++++++++++--------- 1 file changed, 29 insertions(+), 15 deletions(-) diff --git a/devito/passes/clusters/derivatives.py b/devito/passes/clusters/derivatives.py index 940d241343..fe6b5ed44b 100644 --- a/devito/passes/clusters/derivatives.py +++ b/devito/passes/clusters/derivatives.py @@ -3,7 +3,7 @@ import numpy as np from sympy import S -from devito.finite_differences import IndexDerivative +from devito.finite_differences import IndexDerivative, Weights from devito.ir import Backward, Forward, Interval, IterationSpace, Queue from devito.passes.clusters.misc import fuse from devito.symbolics import BasicWrapperMixin, reuse_if_untouched, uxreplace @@ -91,17 +91,39 @@ def _core(expr, c, ispace, weights, reusables, mapper, **kwargs): @_core.register(Symbol) -@_core.register(Indexed) @_core.register(BasicWrapperMixin) def _(expr, c, ispace, weights, reusables, mapper, **kwargs): return expr, [] +@_core.register(Indexed) +def _(expr, c, ispace, weights, reusables, mapper, **kwargs): + if not isinstance(expr.function, Weights): + return expr, [] + + # Lower or reuse a previously lowered Weights array + sregistry = kwargs['sregistry'] + subs_user = kwargs['subs'] + + w0 = expr.function + k = tuple(w0.weights) + try: + w = weights[k] + except KeyError: + name = sregistry.make_name(prefix='w') + dtype = infer_dtype([w0.dtype, c.dtype]) # At least np.float32 + initvalue = tuple(i.subs(subs_user) for i in k) + w = weights[k] = w0._rebuild(name=name, dtype=dtype, initvalue=initvalue) + + rebuilt = expr._subs(w0.indexed, w.indexed) + + return rebuilt, [] + + @_core.register(IndexDerivative) def _(expr, c, ispace, weights, reusables, mapper, **kwargs): sregistry = kwargs['sregistry'] options = kwargs['options'] - subs_user = kwargs['subs'] try: cbk0 = deriv_schedule_registry[options['deriv-schedule']] @@ -114,18 +136,10 @@ def _(expr, c, ispace, weights, reusables, mapper, **kwargs): # Create the concrete Weights array, or reuse an already existing one # if possible - name = sregistry.make_name(prefix='w') - w0 = ideriv.weights.function - dtype = infer_dtype([w0.dtype, c.dtype]) # At least np.float32 - k = tuple(w0.weights) - try: - w = weights[k] - except KeyError: - initvalue = tuple(i.subs(subs_user) for i in k) - w = weights[k] = w0._rebuild(name=name, dtype=dtype, initvalue=initvalue) + w, _ = _core(ideriv.weights, c, ispace, weights, reusables, mapper, **kwargs) # Replace the abstract Weights array with the concrete one - subs = {w0.indexed: w.indexed} + subs = {ideriv.weights.base: w.base} init = uxreplace(init, subs) ideriv = uxreplace(ideriv, subs) @@ -155,10 +169,10 @@ def _(expr, c, ispace, weights, reusables, mapper, **kwargs): # NOTE: created before recurring so that we ultimately get a sound ordering try: s = reusables.pop() - assert np.can_cast(s.dtype, dtype) + assert np.can_cast(s.dtype, w.dtype) except KeyError: name = sregistry.make_name(prefix='r') - s = Symbol(name=name, dtype=dtype) + s = Symbol(name=name, dtype=w.dtype) # Go inside `expr` and recursively lower any nested IndexDerivatives expr, processed = _core(expr, c, ispace1, weights, reusables, mapper, **kwargs) From d6fcbd999d850a0e8fb49b3af8bfed87e28f40c6 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Mon, 11 Aug 2025 11:21:12 +0100 Subject: [PATCH 10/22] compiler: Avoid CSE across Reserved keywords --- devito/passes/clusters/cse.py | 4 ++-- devito/symbolics/extended_sympy.py | 24 +++++++++++++++++++----- devito/symbolics/inspection.py | 4 ++-- 3 files changed, 23 insertions(+), 9 deletions(-) diff --git a/devito/passes/clusters/cse.py b/devito/passes/clusters/cse.py index d4d7f0a8b8..ad12b1dd3a 100644 --- a/devito/passes/clusters/cse.py +++ b/devito/passes/clusters/cse.py @@ -13,9 +13,8 @@ from devito.finite_differences.differentiable import IndexDerivative from devito.ir import Cluster, Scope, cluster_pass -from devito.symbolics import estimate_cost, q_leaf, q_terminal +from devito.symbolics import Reserved, estimate_cost, q_leaf, q_terminal, search from devito.symbolics.manipulation import _uxreplace -from devito.symbolics.search import search from devito.tools import DAG, as_list, as_tuple, extract_dtype, frozendict from devito.types import Eq, Symbol, Temp @@ -399,6 +398,7 @@ def _(expr): @_catch.register(Indexed) @_catch.register(Symbol) +@_catch.register(Reserved) def _(expr): """ Handler for objects preventing CSE to propagate through their arguments. diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py index 2c352d50d4..c00526bb88 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py @@ -19,12 +19,12 @@ from devito.types.basic import Basic __all__ = ['CondEq', 'CondNe', 'BitwiseNot', 'BitwiseXor', 'BitwiseAnd', # noqa - 'LeftShift', 'RightShift', 'IntDiv', 'CallFromPointer', + 'LeftShift', 'RightShift', 'IntDiv', 'CallFromPointer', 'CallFromComposite', 'FieldFromPointer', 'FieldFromComposite', 'ListInitializer', 'Byref', 'IndexedPointer', 'Cast', 'DefFunction', - 'MathFunction', 'InlineIf', 'ReservedWord', 'Keyword', 'String', - 'Macro', 'Class', 'MacroArgument', 'Deref', 'Namespace', 'Rvalue', - 'Null', 'SizeOf', 'rfunc', 'BasicWrapperMixin', 'ValueLimit', + 'MathFunction', 'InlineIf', 'Reserved', 'ReservedWord', 'Keyword', + 'String', 'Macro', 'Class', 'MacroArgument', 'Deref', 'Namespace', + 'Rvalue', 'Null', 'SizeOf', 'rfunc', 'BasicWrapperMixin', 'ValueLimit', 'VectorAccess'] @@ -533,7 +533,21 @@ def __str__(self): __reduce_ex__ = Pickable.__reduce_ex__ -class ReservedWord(sympy.Atom, Pickable): +class Reserved(Pickable): + + """ + A base class for all reserved words used throughout the lowering process, + including the final stage of code generation itself. + + Reserved objects have the following properties: + + * `estimate_cost(o) = 0`, where `o` is an instance of Reserved + """ + + pass + + +class ReservedWord(sympy.Atom, Reserved): """ A `ReservedWord` carries a value that has special meaning in the diff --git a/devito/symbolics/inspection.py b/devito/symbolics/inspection.py index 3147118de1..3dd0bc55bb 100644 --- a/devito/symbolics/inspection.py +++ b/devito/symbolics/inspection.py @@ -11,7 +11,7 @@ from devito.logger import warning from devito.symbolics.extended_dtypes import INT from devito.symbolics.extended_sympy import ( - CallFromPointer, Cast, DefFunction, ReservedWord + CallFromPointer, Cast, DefFunction, Reserved ) from devito.symbolics.queries import q_routine from devito.tools import as_tuple, is_integer, prod @@ -179,7 +179,7 @@ def _(expr, estimate, seen): @_estimate_cost.register(ImaginaryUnit) @_estimate_cost.register(Number) -@_estimate_cost.register(ReservedWord) +@_estimate_cost.register(Reserved) def _(expr, estimate, seen): return 0, False From 939a8d8b95ea4e3d12794482a8688408dc58408b Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Mon, 11 Aug 2025 15:29:04 +0100 Subject: [PATCH 11/22] compiler: Introduce Terminal mixin for SymPy subclasses --- devito/symbolics/extended_sympy.py | 25 ++++++++++++++++++------- devito/symbolics/queries.py | 14 ++++---------- 2 files changed, 22 insertions(+), 17 deletions(-) diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py index c00526bb88..452cab7309 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py @@ -19,7 +19,7 @@ from devito.types.basic import Basic __all__ = ['CondEq', 'CondNe', 'BitwiseNot', 'BitwiseXor', 'BitwiseAnd', # noqa - 'LeftShift', 'RightShift', 'IntDiv', 'CallFromPointer', + 'LeftShift', 'RightShift', 'IntDiv', 'Terminal', 'CallFromPointer', 'CallFromComposite', 'FieldFromPointer', 'FieldFromComposite', 'ListInitializer', 'Byref', 'IndexedPointer', 'Cast', 'DefFunction', 'MathFunction', 'InlineIf', 'Reserved', 'ReservedWord', 'Keyword', @@ -148,6 +148,17 @@ def __mul__(self, other): return super().__mul__(other) +class Terminal: + + """ + Abstract base class for all terminal objects, that is, those objects + collected by `retrieve_terminals` in addition to all other SymPy atoms + such as `Symbol`, `Number`, etc. + """ + + pass + + class BasicWrapperMixin: """ @@ -189,7 +200,7 @@ def _sympystr(self, printer): return str(self) -class CallFromPointer(sympy.Expr, Pickable, BasicWrapperMixin): +class CallFromPointer(Expr, Pickable, BasicWrapperMixin, Terminal): """ Symbolic representation of the C notation ``pointer->call(params)``. @@ -257,7 +268,7 @@ def free_symbols(self): __reduce_ex__ = Pickable.__reduce_ex__ -class CallFromComposite(CallFromPointer, Pickable): +class CallFromComposite(CallFromPointer): """ Symbolic representation of the C notation ``composite.call(params)``. @@ -270,7 +281,7 @@ def __str__(self): __repr__ = __str__ -class FieldFromPointer(CallFromPointer, Pickable): +class FieldFromPointer(CallFromPointer): """ Symbolic representation of the C notation ``pointer->field``. @@ -291,7 +302,7 @@ def field(self): __repr__ = __str__ -class FieldFromComposite(CallFromPointer, Pickable): +class FieldFromComposite(CallFromPointer): """ Symbolic representation of the C notation ``composite.field``, @@ -353,7 +364,7 @@ def is_numeric(self): __reduce_ex__ = Pickable.__reduce_ex__ -class UnaryOp(sympy.Expr, Pickable, BasicWrapperMixin): +class UnaryOp(Expr, Pickable, BasicWrapperMixin): """ Symbolic representation of a unary C operator. @@ -486,7 +497,7 @@ def __str__(self): return f"{self._op}{self.base}" -class IndexedPointer(sympy.Expr, Pickable, BasicWrapperMixin): +class IndexedPointer(Expr, Pickable, BasicWrapperMixin, Terminal): """ Symbolic representation of the C notation ``symbol[...]`` diff --git a/devito/symbolics/queries.py b/devito/symbolics/queries.py index c8624c3578..4bf3e38c01 100644 --- a/devito/symbolics/queries.py +++ b/devito/symbolics/queries.py @@ -1,8 +1,6 @@ from sympy import Eq, IndexedBase, Mod, S, diff, nan -from devito.symbolics.extended_sympy import ( - FieldFromComposite, FieldFromPointer, IndexedPointer, IntDiv -) +from devito.symbolics.extended_sympy import IntDiv, Terminal from devito.tools import as_tuple, is_integer from devito.types.array import ComponentAccess from devito.types.basic import AbstractFunction @@ -32,13 +30,9 @@ ] -# The following SymPy objects are considered tree leaves: -# -# * Number -# * Symbol -# * Indexed -extra_leaves = (FieldFromPointer, FieldFromComposite, IndexedBase, AbstractObject, - IndexedPointer) +# The following SymPy objects are considered tree leaves in addition to the classic +# SymPy atoms such as Number, Symbol, Indexed, etc +extra_leaves = (IndexedBase, AbstractObject, Terminal) def q_symbol(expr): From 80a6c12b8f80bc065299af8e51aae44d5d94a07b Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Tue, 12 Aug 2025 11:38:13 +0100 Subject: [PATCH 12/22] compiler: Pass ctx down to _map_function_on_high_bw_mem --- devito/passes/iet/definitions.py | 52 ++++++++++++++++---------------- 1 file changed, 26 insertions(+), 26 deletions(-) diff --git a/devito/passes/iet/definitions.py b/devito/passes/iet/definitions.py index 1d5b0a69d4..59d2b97347 100644 --- a/devito/passes/iet/definitions.py +++ b/devito/passes/iet/definitions.py @@ -563,7 +563,7 @@ class DeviceAwareDataManager(DataManager): def __init__(self, options=None, **kwargs): self.gpu_fit = options['gpu-fit'] self.gpu_create = options['gpu-create'] - self.pmode = options.get('place-transfers') + self.gpu_place_transfers = options.get('place-transfers') super().__init__(**kwargs) @@ -596,7 +596,8 @@ def _map_array_on_high_bw_mem(self, site, obj, storage): storage.update(obj, site, maps=mmap, unmaps=unmap) - def _map_function_on_high_bw_mem(self, site, obj, storage, devicerm, read_only=False): + def _map_function_on_high_bw_mem(self, site, obj, storage, devicerm, + read_only=False, **kwargs): """ Map a Function already defined in the host memory in to the device high bandwidth memory. @@ -629,42 +630,41 @@ def _map_function_on_high_bw_mem(self, site, obj, storage, devicerm, read_only=F storage.update(obj, site, maps=mmap, unmaps=unmap, efuncs=efuncs) @iet_pass - def place_transfers(self, iet, data_movs=None, **kwargs): + def place_transfers(self, iet, data_movs=None, ctx=None, **kwargs): """ Create a new IET with host-device data transfers. This requires mapping symbols to the suitable memory spaces. """ - if not self.pmode: + if not self.gpu_place_transfers: return iet, {} - @singledispatch - def _place_transfers(iet, data_movs): + if not isinstance(iet, EntryFunction): return iet, {} - @_place_transfers.register(EntryFunction) - def _(iet, data_movs): - reads, writes = data_movs + reads, writes = data_movs - # Special symbol which gives user code control over data deallocations - devicerm = DeviceRM() + # Special symbol which gives user code control over data deallocations + devicerm = DeviceRM() - storage = Storage() - for i in filter_sorted(writes): - if i.is_Array: - self._map_array_on_high_bw_mem(iet, i, storage) - else: - self._map_function_on_high_bw_mem(iet, i, storage, devicerm) - for i in filter_sorted(reads - writes): - if i.is_Array: - self._map_array_on_high_bw_mem(iet, i, storage) - else: - self._map_function_on_high_bw_mem(iet, i, storage, devicerm, True) - - iet, efuncs = self._inject_definitions(iet, storage) + storage = Storage() + for i in filter_sorted(writes): + if i.is_Array: + self._map_array_on_high_bw_mem(iet, i, storage) + else: + self._map_function_on_high_bw_mem( + iet, i, storage, devicerm, ctx=ctx + ) + for i in filter_sorted(reads - writes): + if i.is_Array: + self._map_array_on_high_bw_mem(iet, i, storage) + else: + self._map_function_on_high_bw_mem( + iet, i, storage, devicerm, read_only=True, ctx=ctx + ) - return iet, {'efuncs': efuncs} + iet, efuncs = self._inject_definitions(iet, storage) - return _place_transfers(iet, data_movs=data_movs) + return iet, {'efuncs': efuncs} @iet_pass def place_devptr(self, iet, **kwargs): From 41b8af866b282a158ea38c847b072d7fe82a112f Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Tue, 12 Aug 2025 15:30:20 +0100 Subject: [PATCH 13/22] compiler: Enhance _alloc_object_on_low_lat_mem --- devito/passes/iet/definitions.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/devito/passes/iet/definitions.py b/devito/passes/iet/definitions.py index 59d2b97347..102372d618 100644 --- a/devito/passes/iet/definitions.py +++ b/devito/passes/iet/definitions.py @@ -102,6 +102,10 @@ def _alloc_object_on_low_lat_mem(self, site, obj, storage): if not init: definition = decl efuncs = () + elif isinstance(init, (list, tuple)): + assert len(init) == 2, "Expected (efunc, call)" + init, definition = init + efuncs = (init,) elif init.is_Callable: definition = Call(init.name, init.parameters, retobj=obj if init.retval else None) From c8458710fbc715d6624fa6fe50d1c213ab20aa52 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Tue, 12 Aug 2025 15:30:49 +0100 Subject: [PATCH 14/22] compiler: Fix abstract_object(Array) --- devito/passes/iet/engine.py | 31 ++++++++++++++++++++++++------- devito/types/array.py | 2 ++ devito/types/misc.py | 1 + devito/types/parallel.py | 2 ++ 4 files changed, 29 insertions(+), 7 deletions(-) diff --git a/devito/passes/iet/engine.py b/devito/passes/iet/engine.py index 5a383332e3..08d7c83aad 100644 --- a/devito/passes/iet/engine.py +++ b/devito/passes/iet/engine.py @@ -17,9 +17,9 @@ from devito.symbolics import FieldFromComposite, FieldFromPointer, IndexedPointer, search from devito.tools import DAG, as_tuple, filter_ordered, sorted_priority, timed_pass from devito.types import ( - Array, Bundle, ComponentAccess, CompositeObject, IncrDimension, Indirection, Lock, - ModuloDimension, NPThreads, NThreadsBase, Pointer, SharedData, Symbol, Temp, - ThreadArray, Wildcard + Array, Bundle, ComponentAccess, CompositeObject, FunctionMap, IncrDimension, + Indirection, ModuloDimension, NPThreads, NThreadsBase, Pointer, SharedData, + Symbol, Temp, ThreadArray, Wildcard ) from devito.types.args import ArgProvider from devito.types.dense import DiscreteFunction @@ -550,12 +550,19 @@ def _(i, mapper, sregistry): @abstract_object.register(Array) def _(i, mapper, sregistry): - if isinstance(i, Lock): - name = sregistry.make_name(prefix='lock') + name = sregistry.make_name(prefix=i._symbol_prefix) + + if i.initvalue is not None: + initvalue = [] + for v in i.initvalue: + try: + initvalue.append(v.xreplace(mapper)) + except AttributeError: + initvalue.append(v) else: - name = sregistry.make_name(prefix='a') + initvalue = None - v = i._rebuild(name=name, alias=True) + v = i._rebuild(name=name, initvalue=initvalue, alias=True) mapper.update({ i: v, @@ -662,6 +669,16 @@ def _(i, mapper, sregistry): mapper[i] = i._rebuild(name=sregistry.make_name(prefix='ptr')) +@abstract_object.register(FunctionMap) +def _(i, mapper, sregistry): + name = sregistry.make_name(prefix=i._symbol_prefix) + tensor = mapper.get(i.tensor, i.tensor) + + v = i._rebuild(name, tensor) + + mapper[i] = v + + @abstract_object.register(NPThreads) def _(i, mapper, sregistry): mapper[i] = i._rebuild(name=sregistry.make_name(prefix='npthreads')) diff --git a/devito/types/array.py b/devito/types/array.py index c48425e33d..bd7747d40d 100644 --- a/devito/types/array.py +++ b/devito/types/array.py @@ -134,6 +134,8 @@ class Array(ArrayBasic): is_Array = True + _symbol_prefix = 'a' + __rkwargs__ = (ArrayBasic.__rkwargs__ + ('dimensions', 'scope', 'initvalue')) diff --git a/devito/types/misc.py b/devito/types/misc.py index 9c06cc5a70..40679bb9b6 100644 --- a/devito/types/misc.py +++ b/devito/types/misc.py @@ -19,6 +19,7 @@ 'CriticalRegion', 'FIndexed', 'Fence', + 'FunctionMap', 'Global', 'Hyperplane', 'Indirection', diff --git a/devito/types/parallel.py b/devito/types/parallel.py index ff9e1405a9..8300d0e80c 100644 --- a/devito/types/parallel.py +++ b/devito/types/parallel.py @@ -245,6 +245,8 @@ class Lock(Array): is_volatile = True + _symbol_prefix = 'lock' + # Not a performance-sensitive object _data_alignment = False From 8dade8050b9e9b9904ab8971c73db6b9eb11dd8a Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Wed, 13 Aug 2025 11:44:33 +0100 Subject: [PATCH 15/22] compiler: Avoid spurious items in sub_iters and dirs --- devito/ir/support/space.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/devito/ir/support/space.py b/devito/ir/support/space.py index 0fcaf8f423..a1c0f7b7a8 100644 --- a/devito/ir/support/space.py +++ b/devito/ir/support/space.py @@ -789,17 +789,19 @@ def __init__(self, intervals, sub_iterators=None, directions=None): super().__init__(intervals) # Normalize sub-iterators - sub_iterators = dict([(k, tuple(filter_ordered(as_tuple(v)))) - for k, v in (sub_iterators or {}).items()]) + sub_iterators = sub_iterators or {} + sub_iterators = {d: tuple(filter_ordered(as_tuple(v))) + for d, v in sub_iterators.items() if d in self.intervals} sub_iterators.update({i.dim: () for i in self.intervals if i.dim not in sub_iterators}) self._sub_iterators = frozendict(sub_iterators) # Normalize directions - if directions is None: - self._directions = frozendict([(i.dim, Any) for i in self.intervals]) - else: - self._directions = frozendict(directions) + directions = directions or {} + directions = {d: v for d, v in directions.items() if d in self.intervals} + directions.update({i.dim: Any for i in self.intervals + if i.dim not in directions}) + self._directions = frozendict(directions) def __repr__(self): ret = ', '.join([f"{repr(i)}{repr(self.directions[i.dim])}" @@ -821,8 +823,7 @@ def __lt__(self, other): return len(self.itintervals) < len(other.itintervals) def __hash__(self): - return hash((super().__hash__(), self.sub_iterators, - self.directions)) + return hash((super().__hash__(), self.sub_iterators, self.directions)) def __contains__(self, d): try: From 87ac4011d65cc389bc0e2e9009ccd7abcf50a709 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Thu, 14 Aug 2025 10:05:47 +0100 Subject: [PATCH 16/22] misc: Fix typo --- devito/passes/clusters/derivatives.py | 2 +- devito/passes/iet/definitions.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/devito/passes/clusters/derivatives.py b/devito/passes/clusters/derivatives.py index fe6b5ed44b..194d26523d 100644 --- a/devito/passes/clusters/derivatives.py +++ b/devito/passes/clusters/derivatives.py @@ -166,7 +166,7 @@ def _(expr, c, ispace, weights, reusables, mapper, **kwargs): ispace1 = IterationSpace.union(ispace, ispace0, relations=extra) # The Symbol that will hold the result of the IndexDerivative computation - # NOTE: created before recurring so that we ultimately get a sound ordering + # NOTE: created before recursing so that we ultimately get a sound ordering try: s = reusables.pop() assert np.can_cast(s.dtype, w.dtype) diff --git a/devito/passes/iet/definitions.py b/devito/passes/iet/definitions.py index 102372d618..2f1cce8f10 100644 --- a/devito/passes/iet/definitions.py +++ b/devito/passes/iet/definitions.py @@ -5,7 +5,6 @@ from collections import OrderedDict from ctypes import c_uint64 -from functools import singledispatch from operator import itemgetter import numpy as np From 4dbb6d1e4697564f5c46cedd4a1269b91769ce85 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Thu, 14 Aug 2025 13:38:36 +0100 Subject: [PATCH 17/22] compiler: Pass kwargs to make_parallel --- devito/passes/iet/parpragma.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/devito/passes/iet/parpragma.py b/devito/passes/iet/parpragma.py index ec02a5e3cb..eb6916622e 100644 --- a/devito/passes/iet/parpragma.py +++ b/devito/passes/iet/parpragma.py @@ -432,7 +432,7 @@ def _make_parallel(self, iet, sync_mapper=None): return iet, {'includes': [self.langbb['header']]} - def make_parallel(self, graph): + def make_parallel(self, graph, **kwargs): return self._make_parallel(graph, sync_mapper=graph.sync_mapper) From 8542126175cafbe3b6c9c274931d1182156220c5 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Wed, 26 Nov 2025 17:41:19 +0000 Subject: [PATCH 18/22] compiler: Tweak pairwise_or --- devito/ir/support/guards.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/devito/ir/support/guards.py b/devito/ir/support/guards.py index dbb1fe9ca7..4b334545a6 100644 --- a/devito/ir/support/guards.py +++ b/devito/ir/support/guards.py @@ -500,7 +500,9 @@ def pairwise_or(*guards): # Analysis for guard in guards: - if guard is true or guard is None: + if guard is true: + return true + elif guard is None: continue elif isinstance(guard, And): components = guard.args From abdda0716ecf1b04c550dcad9b9bb2bf0f4bab64 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Thu, 4 Dec 2025 15:19:30 +0000 Subject: [PATCH 19/22] compiler: Enhance ListInitializer --- devito/symbolics/extended_sympy.py | 8 ++++++-- tests/test_symbolics.py | 25 +++++++++++++++++++++++++ 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/devito/symbolics/extended_sympy.py b/devito/symbolics/extended_sympy.py index 452cab7309..ea8179d190 100644 --- a/devito/symbolics/extended_sympy.py +++ b/devito/symbolics/extended_sympy.py @@ -334,10 +334,14 @@ class ListInitializer(sympy.Expr, Pickable): Symbolic representation of the C++ list initializer notation ``{a, b, ...}``. """ - __rargs__ = ('params',) + __rargs__ = ('*params',) __rkwargs__ = ('dtype',) - def __new__(cls, params, dtype=None): + def __new__(cls, *params, dtype=None, evaluate=False): + # Legacy API: allow a single list/tuple as argument + if len(params) == 1 and isinstance(params[0], (list, tuple, np.ndarray)): + params = params[0] + args = [] for p in as_tuple(params): try: diff --git a/tests/test_symbolics.py b/tests/test_symbolics.py index 85d2818940..6aedd92cf5 100644 --- a/tests/test_symbolics.py +++ b/tests/test_symbolics.py @@ -428,6 +428,31 @@ def test_namespace(): assert not ns0.free_symbols +def test_list_initializer(): + # Legacy interface + init0 = ListInitializer((1, 2, 3)) + assert str(init0) == '{1, 2, 3}' + + init1 = ListInitializer(1, 2, 3) + assert str(init1) == '{1, 2, 3}' + + # Test hashing and equality + assert init0 == init1 + assert hash(init0) == hash(init1) + init2 = ListInitializer(1, 2) + assert init0 != init2 + assert hash(init0) != hash(init2) + assert hash(init0) == hash(init1) + + # Reconstruction + assert init0 == init0._rebuild() + assert init1 == init1._rebuild() + assert str(init1._rebuild(4, 5)) == '{4, 5}' + + # Accept `evaluate` but gently ignore it + assert str(ListInitializer((1, 2), evaluate=True)) == '{1, 2}' + + def test_rvalue(): ctype = ReservedWord('dummytype') ns = Namespace(['my', 'namespace']) From 6dcff845903d7a1f8dc11b582ca40067487b6fd8 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Wed, 21 Jan 2026 15:55:01 +0000 Subject: [PATCH 20/22] compiler: Bump SafeInv cost --- devito/symbolics/inspection.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/devito/symbolics/inspection.py b/devito/symbolics/inspection.py index 3dd0bc55bb..8ecd5ab4fa 100644 --- a/devito/symbolics/inspection.py +++ b/devito/symbolics/inspection.py @@ -116,8 +116,8 @@ def estimate_cost(exprs, estimate=False): estimate_values = { 'elementary': 100, + 'SafeInv': 75, 'pow': 50, - 'SafeInv': 50, 'div': 5, 'Abs': 5, 'floor': 1, From 489005b9b252ed35ba00533afef2d18607968629 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Tue, 27 Jan 2026 10:56:52 +0000 Subject: [PATCH 21/22] compiler: Fix Cluster.used_dimensions --- devito/ir/clusters/cluster.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/devito/ir/clusters/cluster.py b/devito/ir/clusters/cluster.py index 6c65f4b97a..37ac0259e1 100644 --- a/devito/ir/clusters/cluster.py +++ b/devito/ir/clusters/cluster.py @@ -179,8 +179,14 @@ def used_dimensions(self): example, reduction or redundant (i.e., invariant) Dimensions won't appear in an expression. """ - idims = set.union(*[set(e.implicit_dims) for e in self.exprs]) - return {i for i in self.free_symbols if i.is_Dimension} | idims + dims_exprs = {i for i in self.free_symbols if i.is_Dimension} + + dims_implicit = set().union(*[set(e.implicit_dims) for e in self.exprs]) + + syms_guards = set().union(*[e.free_symbols for e in self.guards.values()]) + dims_guards = {i for i in syms_guards if i.is_Dimension} + + return dims_exprs | dims_implicit | dims_guards @cached_property def dist_dimensions(self): From 206c29c0272ae19a077f82953147e4f3cf87dfca Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Tue, 3 Feb 2026 09:49:36 +0000 Subject: [PATCH 22/22] compiler: Split up Cluster.used_dimensions to fix Lift --- devito/ir/clusters/cluster.py | 29 +++++++++++++++++++++-------- devito/passes/clusters/misc.py | 14 +++++++------- 2 files changed, 28 insertions(+), 15 deletions(-) diff --git a/devito/ir/clusters/cluster.py b/devito/ir/clusters/cluster.py index 37ac0259e1..48cf7c9a68 100644 --- a/devito/ir/clusters/cluster.py +++ b/devito/ir/clusters/cluster.py @@ -172,21 +172,34 @@ def dimensions(self): return set().union(*[i._defines for i in self.ispace.dimensions]) @cached_property - def used_dimensions(self): + def exprs_dimensions(self): """ - The Dimensions that *actually* appear among the expressions in ``self``. - These do not necessarily coincide the IterationSpace Dimensions; for - example, reduction or redundant (i.e., invariant) Dimensions won't - appear in an expression. + The Dimensions that appear explicitly in the Cluster expressions. """ - dims_exprs = {i for i in self.free_symbols if i.is_Dimension} - + dims_explicit = {i for i in self.free_symbols if i.is_Dimension} dims_implicit = set().union(*[set(e.implicit_dims) for e in self.exprs]) + return dims_explicit | dims_implicit + @cached_property + def guards_dimensions(self): + """ + The Dimensions that appear explicitly in the Cluster guards. + """ syms_guards = set().union(*[e.free_symbols for e in self.guards.values()]) dims_guards = {i for i in syms_guards if i.is_Dimension} + return dims_guards - return dims_exprs | dims_implicit | dims_guards + @cached_property + def used_dimensions(self): + """ + All the Dimensions that appear explicitly either within the expressions + or the guards. + + Note that, in some cases, some of the IterationSpace Dimensions might + not appear here among the used Dimensions -- for example, reduction or + redundant (i.e., invariant) Dimensions. + """ + return self.exprs_dimensions | self.guards_dimensions @cached_property def dist_dimensions(self): diff --git a/devito/passes/clusters/misc.py b/devito/passes/clusters/misc.py index 0324f03eae..f248061252 100644 --- a/devito/passes/clusters/misc.py +++ b/devito/passes/clusters/misc.py @@ -55,7 +55,7 @@ def callback(self, clusters, prefix): continue # Is `c` a real candidate -- is there at least one invariant Dimension? - if any(d._defines & hope_invariant for d in c.used_dimensions): + if any(d._defines & hope_invariant for d in c.exprs_dimensions): processed.append(c) continue @@ -69,16 +69,16 @@ def callback(self, clusters, prefix): # All of the inner Dimensions must appear in the write-to region # otherwise we would violate data dependencies. Consider # - # 1) 2) 3) - # for i for i for i - # for x for x for x - # r = f(a[x]) for y for y - # r[x] = f(a[x, y]) r[x, y] = f(a[x, y]) + # 1) 2) 3) + # for i for i for i + # for x for x for x + # r = f(a[x]) for y for y + # r[x] = f(a[x, y]) r[x, y] = f(a[x, y]) # # In 1) and 2) lifting is infeasible; in 3) the statement can # be lifted outside the `i` loop as `r`'s write-to region contains # both `x` and `y` - xed = {d._defines for d in c.used_dimensions if d not in outer} + xed = {d._defines for d in c.exprs_dimensions if d not in outer} if not all(i & set(w.dimensions) for i, w in product(xed, c.scope.writes)): processed.append(c) continue