From aaa36f79663a59fa7bf214bdf074b5cf728a949a Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Tue, 24 Mar 2026 10:37:12 +0000 Subject: [PATCH 01/45] compiler: Augment caching and memoization --- devito/finite_differences/derivative.py | 5 +- devito/finite_differences/differentiable.py | 4 + .../finite_differences/finite_difference.py | 5 +- devito/finite_differences/tools.py | 6 +- devito/ir/clusters/cluster.py | 7 +- devito/ir/equations/equation.py | 95 ++++++++++++++++--- devito/ir/iet/nodes.py | 4 +- devito/ir/support/basic.py | 41 ++------ devito/ir/support/utils.py | 69 +------------- devito/operator/operator.py | 17 ++-- devito/symbolics/search.py | 38 +++++++- devito/tools/memoization.py | 27 +++++- devito/types/caching.py | 6 +- tests/test_ir.py | 12 +++ tests/test_symbolics.py | 12 +++ tests/test_tools.py | 1 - 16 files changed, 210 insertions(+), 139 deletions(-) diff --git a/devito/finite_differences/derivative.py b/devito/finite_differences/derivative.py index 24bbdea972..08feb17839 100644 --- a/devito/finite_differences/derivative.py +++ b/devito/finite_differences/derivative.py @@ -6,7 +6,9 @@ import sympy -from devito.tools import Pickable, as_mapper, as_tuple, frozendict, is_integer +from devito.tools import ( + Pickable, as_mapper, as_tuple, frozendict, is_integer, memoized_func +) from devito.types.dimension import Dimension from devito.types.utils import DimensionTuple from devito.warnings import warn @@ -557,6 +559,7 @@ def _evaluate(self, **kwargs): def _eval_deriv(self): return self._eval_fd(self.expr) + @memoized_func(scope='build') def _eval_fd(self, expr, **kwargs): """ Evaluate the finite-difference approximation of the Derivative. 
diff --git a/devito/finite_differences/differentiable.py b/devito/finite_differences/differentiable.py index 6322b8ca8f..fed3baecc3 100644 --- a/devito/finite_differences/differentiable.py +++ b/devito/finite_differences/differentiable.py @@ -1028,6 +1028,10 @@ def compare(self, other): def base(self): return self.expr.func(*[a for a in self.expr.args if a is not self.weights]) + @cached_property + def pivot(self): + return self.base.subs({d: 0 for d in self.dimensions}) + @property def weights(self): return self._weights diff --git a/devito/finite_differences/finite_difference.py b/devito/finite_differences/finite_difference.py index 30199fb3d8..bdf3199b0d 100644 --- a/devito/finite_differences/finite_difference.py +++ b/devito/finite_differences/finite_difference.py @@ -170,14 +170,15 @@ def make_derivative(expr, dim, fd_order, deriv_order, side, matvec, x0, coeffici # `coefficients` method (`taylor` or `symbolic`) if weights is None: weights = fd_weights_registry[coefficients](expr, deriv_order, indices, x0) - if isinstance(weights, Iterable) and len(weights) != len(indices): + _, wdim, _ = process_weights(weights, expr, dim) + elif isinstance(weights, Iterable) and len(weights) != len(indices): warning(f"Number of weights ({len(weights)}) does not match " f"number of indices ({len(indices)}), reverting to Taylor") scale = False + wdim = None weights = fd_weights_registry['taylor'](expr, deriv_order, indices, x0) # Did fd_weights_registry return a new Function/Expression instead of a values? 
- _, wdim, _ = process_weights(weights, expr, dim) if wdim is not None: weights = [weights._subs(wdim, i) for i in range(len(indices))] diff --git a/devito/finite_differences/tools.py b/devito/finite_differences/tools.py index 438c4da0c9..8c9304b126 100644 --- a/devito/finite_differences/tools.py +++ b/devito/finite_differences/tools.py @@ -228,10 +228,14 @@ def make_stencil_dimension(expr, _min, _max): @cacheit -def numeric_weights(function, deriv_order, indices, x0): +def _numeric_weights(deriv_order, indices, x0): return finite_diff_weights(deriv_order, indices, x0)[-1][-1] +def numeric_weights(function, deriv_order, indices, x0): + return _numeric_weights(deriv_order, indices, x0) + + fd_weights_registry = {'taylor': numeric_weights, 'standard': numeric_weights, 'symbolic': numeric_weights} # Backward compat for 'symbolic' coeff_priority = {'taylor': 1, 'standard': 1} diff --git a/devito/ir/clusters/cluster.py b/devito/ir/clusters/cluster.py index cbf206b3ff..cbced4e606 100644 --- a/devito/ir/clusters/cluster.py +++ b/devito/ir/clusters/cluster.py @@ -8,8 +8,8 @@ from devito.ir.support import ( PARALLEL, PARALLEL_IF_PVT, BaseGuardBoundNext, DataSpace, Forward, Guards, Interval, IntervalGroup, IterationSpace, PrefetchUpdate, Properties, Scope, WaitLock, WithLock, - detect_accesses, detect_io, maximum, minimum, normalize_properties, normalize_syncs, - null_ispace, tailor_properties, update_properties + detect_accesses, maximum, minimum, normalize_properties, normalize_syncs, null_ispace, + tailor_properties, update_properties ) from devito.mpi.halo_scheme import HaloScheme, HaloTouch from devito.mpi.reduction_scheme import DistReduce @@ -491,7 +491,8 @@ def traffic(self): ----- If a Function is both read and written, then it is counted twice. 
""" - reads, writes = detect_io(self.exprs, relax=True) + reads = flatten(i.read_functions_relaxed for i in self.exprs) + writes = flatten(i.write_functions_relaxed for i in self.exprs) accesses = [(i, 'r') for i in reads] + [(i, 'w') for i in writes] # Ordering isn't important at this point, so returning an unordered diff --git a/devito/ir/equations/equation.py b/devito/ir/equations/equation.py index 8d72704b79..e8221a2364 100644 --- a/devito/ir/equations/equation.py +++ b/devito/ir/equations/equation.py @@ -1,3 +1,4 @@ +from contextlib import suppress from functools import cached_property import numpy as np @@ -6,11 +7,10 @@ from devito.finite_differences.differentiable import diff2sympy from devito.ir.equations.algorithms import dimension_sort, lower_exprs from devito.ir.support import ( - GuardFactor, Interval, IntervalGroup, IterationSpace, Stencil, detect_accesses, - detect_io + GuardFactor, Interval, IntervalGroup, IterationSpace, Stencil, detect_accesses ) -from devito.symbolics import IntDiv, limits_mapper, uxreplace -from devito.tools import Pickable, Tag, frozendict +from devito.symbolics import IntDiv, limits_mapper, retrieve_accesses, uxreplace +from devito.tools import Pickable, Tag, filter_sorted, frozendict from devito.types import Eq, Inc, ReduceMax, ReduceMin, ReduceMinMax, relational_min __all__ = [ @@ -80,6 +80,82 @@ def is_Reduction(self): def is_Increment(self): return self.operation is OpInc + @cached_property + def _writes(self): + from devito.symbolics.queries import q_routine + + terminals = set(retrieve_accesses(self.lhs)) + if q_routine(self.rhs): + with suppress(AttributeError): + # Everything except: foreign routines, such as `cos` or `sin` etc. 
+ terminals.update(self.rhs.writes) + + return tuple(terminals) + + @property + def writes(self): + return self._writes + + @cached_property + def reads_explicit(self): + terminals = set(retrieve_accesses(self.rhs, deep=True)) + with suppress(AttributeError): + terminals.update(retrieve_accesses(self.lhs.indices)) + + return tuple(terminals) + + @cached_property + def reads_conditional(self): + accesses = [] + for v in self.conditionals.values(): + accesses.extend(retrieve_accesses(v)) + + return tuple(accesses) + + @cached_property + def _reads(self): + return tuple(set(self.reads_explicit) | set(self.reads_conditional)) + + @property + def reads(self): + return self._reads + + @cached_property + def _read_functions(self): + found = [] + for i in self.reads: + with suppress(AttributeError): + i = i.function + found.append(i) + return tuple(filter_sorted(found)) + + @cached_property + def _write_functions(self): + found = [] + for i in self.writes: + with suppress(AttributeError): + i = i.function + found.append(i) + return tuple(filter_sorted(found)) + + @cached_property + def read_functions(self): + return tuple(i for i in self._read_functions if i.is_Input) + + @cached_property + def write_functions(self): + return tuple(i for i in self._write_functions if i.is_Input) + + @cached_property + def read_functions_relaxed(self): + return tuple(i for i in self._read_functions + if i.is_Input or i.is_AbstractFunction) + + @cached_property + def write_functions_relaxed(self): + return tuple(i for i in self._write_functions + if i.is_Input or i.is_AbstractFunction) + def apply(self, func): """ Apply a callable to `self` and each expr-like attribute carried by `self`, @@ -175,7 +251,7 @@ class LoweredEq(IREq): `LoweredEq.__rkwargs__` must appear in `kwargs`. 
""" - __rkwargs__ = IREq.__rkwargs__ + ('reads', 'writes') + __rkwargs__ = IREq.__rkwargs__ def __new__(cls, *args, **kwargs): if len(args) == 1 and isinstance(args[0], LoweredEq): @@ -250,20 +326,11 @@ def __new__(cls, *args, **kwargs): expr._ispace = ispace expr._conditionals = conditionals - expr._reads, expr._writes = detect_io(expr) expr._implicit_dims = input_expr.implicit_dims expr._operation = Operation.detect(input_expr) return expr - @property - def reads(self): - return self._reads - - @property - def writes(self): - return self._writes - def xreplace(self, rules): return LoweredEq(self.lhs.xreplace(rules), self.rhs.xreplace(rules), **self.state) diff --git a/devito/ir/iet/nodes.py b/devito/ir/iet/nodes.py index d106b5e811..d26b97af7a 100644 --- a/devito/ir/iet/nodes.py +++ b/devito/ir/iet/nodes.py @@ -16,7 +16,7 @@ from devito.ir.equations import DummyEq, OpInc, OpMax, OpMin, OpMinMax from devito.ir.support import ( AFFINE, INBOUND, PARALLEL, PARALLEL_IF_ATOMIC, PARALLEL_IF_PVT, SEQUENTIAL, - VECTORIZED, Forward, PrefetchUpdate, Property, WithLock, detect_io + VECTORIZED, Forward, PrefetchUpdate, Property, WithLock ) from devito.symbolics import CallFromPointer, ListInitializer from devito.tools import ( @@ -452,7 +452,7 @@ def rhs(self): @cached_property def reads(self): """The Functions read by the Expression.""" - return detect_io(self.expr, relax=True)[0] + return self.expr.read_functions_relaxed @cached_property def write(self): diff --git a/devito/ir/support/basic.py b/devito/ir/support/basic.py index 7939ee8fe8..ba26370dda 100644 --- a/devito/ir/support/basic.py +++ b/devito/ir/support/basic.py @@ -10,8 +10,8 @@ from devito.ir.support.utils import AccessMode, extrema from devito.ir.support.vector import LabeledVector, Vector from devito.symbolics import ( - compare_ops, q_affine, q_comp_acc, q_constant, q_routine, retrieve_indexed, - retrieve_terminals, search, uxreplace + compare_ops, q_affine, q_comp_acc, q_constant, retrieve_accesses, + 
retrieve_indexed ) from devito.tools import ( CacheInstances, Tag, as_mapper, as_tuple, filter_sorted, flatten, is_integer, @@ -320,6 +320,7 @@ def lex_le(self, other): def lex_lt(self, other): return self.timestamp < other.timestamp + @memoized_meth def distance(self, other, logical=False): """ Compute the distance from ``self`` to ``other``. @@ -876,13 +877,7 @@ def writes_gen(self): Generate all write accesses. """ for i, e in enumerate(self.exprs): - terminals = retrieve_accesses(e.lhs) - if q_routine(e.rhs): - with suppress(AttributeError): - # Everything except: foreign routines, such as `cos` or `sin` etc. - terminals.update(e.rhs.writes) - - for j in terminals: + for j in e.writes: mode = 'WR' if e.is_Reduction else 'W' yield TimedAccess(j, mode, i, e.ispace) @@ -919,11 +914,7 @@ def reads_explicit_gen(self): expressions. """ for i, e in enumerate(self.exprs): - # Reads - terminals = retrieve_accesses(e.rhs, deep=True) - with suppress(AttributeError): - terminals.update(retrieve_accesses(e.lhs.indices)) - for j in terminals: + for j in e.reads_explicit: mode = 'RR' if j.function is e.lhs.function and e.is_Reduction else 'R' yield TimedAccess(j, mode, i, e.ispace) @@ -932,9 +923,8 @@ def reads_explicit_gen(self): yield TimedAccess(e.lhs, 'RR', i, e.ispace) # Look up ConditionalDimensions - for v in e.conditionals.values(): - for j in retrieve_accesses(v): - yield TimedAccess(j, 'R', -1, e.ispace) + for j in e.reads_conditional: + yield TimedAccess(j, 'R', -1, e.ispace) @memoized_generator def reads_implicit_gen(self): @@ -1381,23 +1371,6 @@ def vinf(entries): return Vector(*(entries + [S.Infinity])) -def retrieve_accesses(exprs, **kwargs): - """ - Like retrieve_terminals, but ensure that if a ComponentAccess is found, - the ComponentAccess itself is returned, while the wrapped Indexed is discarded. 
- """ - kwargs['mode'] = 'unique' - - compaccs = search(exprs, ComponentAccess) - if not compaccs: - return retrieve_terminals(exprs, **kwargs) - - subs = {i: Symbol(f'dummy{n}') for n, i in enumerate(compaccs)} - exprs1 = uxreplace(exprs, subs) - - return compaccs | retrieve_terminals(exprs1, **kwargs) - set(subs.values()) - - def disjoint_test(e0, e1, d, it): """ A rudimentary test to check if two accesses `e0` and `e1` along `d` within diff --git a/devito/ir/support/utils.py b/devito/ir/support/utils.py index 644bab5d4c..5f2ee39af7 100644 --- a/devito/ir/support/utils.py +++ b/devito/ir/support/utils.py @@ -3,8 +3,8 @@ from itertools import product from devito.finite_differences import IndexDerivative -from devito.symbolics import CallFromPointer, retrieve_indexed, retrieve_terminals, search -from devito.tools import DefaultOrderedDict, as_tuple, filter_sorted, flatten, split +from devito.symbolics import retrieve_indexed, search +from devito.tools import DefaultOrderedDict, as_tuple, filter_sorted, split from devito.types import ( Dimension, DimensionTuple, Indirection, ModuloDimension, StencilDimension, TensorMove ) @@ -14,7 +14,6 @@ 'IMask', 'Stencil', 'detect_accesses', - 'detect_io', 'erange', 'extrema', 'maximum', @@ -217,70 +216,6 @@ def detect_accesses(exprs): return mapper -def detect_io(exprs, relax=False): - """ - ``{exprs} -> ({reads}, {writes})`` - - Parameters - ---------- - exprs : expr-like or list of expr-like - The searched expressions. - relax : bool, optional - If False, as by default, collect all Input objects, such as - Constants and Functions. Otherwise, also collect AbstractFunctions. 
- """ - exprs = as_tuple(exprs) - if relax is False: - rule = lambda i: i.is_Input - else: - rule = lambda i: i.is_Input or i.is_AbstractFunction - - # Don't forget the nasty case with indirections on the LHS: - # >>> u[t, a[x]] = f[x] -> (reads={a, f}, writes={u}) - - roots = [] - for i in exprs: - try: - roots.append(i.rhs) - roots.extend(list(i.lhs.indices)) - roots.extend(list(i.conditionals.values())) - except AttributeError: - # E.g., CallFromPointer - roots.append(i) - - reads = [] - terminals = flatten(retrieve_terminals(i, deep=True) for i in roots) - for i in terminals: - candidates = set(i.free_symbols) - with suppress(AttributeError): - candidates.update({i.function}) - for j in candidates: - try: - if rule(j): - reads.append(j) - except AttributeError: - pass - - writes = [] - for i in exprs: - try: - f = i.lhs.function - except AttributeError: - continue - try: - if rule(f): - writes.append(f) - except AttributeError: - # We only end up here after complex IET transformations which make - # use of composite types - assert isinstance(i.lhs, CallFromPointer) - f = i.lhs.base.function - if rule(f): - writes.append(f) - - return tuple(filter_sorted(reads)), tuple(filter_sorted(writes)) - - def pull_dims(exprs, flag=True): """ Extract all Dimensions from one or more expressions. 
If `flag=True` diff --git a/devito/operator/operator.py b/devito/operator/operator.py index 6e77c39281..66e9857132 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -38,8 +38,8 @@ from devito.symbolics import estimate_cost, subs_op_args from devito.tools import ( DAG, CacheInstances, MemoryEstimate, OrderedSet, ReducerMap, Signer, as_mapper, - as_tuple, contains_val, filter_sorted, flatten, frozendict, is_integer, split, - timed_pass, timed_region + as_tuple, contains_val, filter_sorted, flatten, frozendict, is_integer, + memoized_func, split, timed_pass, timed_region ) from devito.types import Buffer, Evaluable, device_layer, disk_layer, host_layer from devito.types.dimension import Thickness @@ -184,7 +184,11 @@ def __new__(cls, expressions, **kwargs): # Lower to a JIT-compilable object with timed_region('op-compile') as r: - op = cls._build(expressions, **kwargs) + try: + op = cls._build(expressions, **kwargs) + finally: + CacheInstances.clear_caches() + memoized_func.clear_build_caches() op._profiler.py_timers.update(r.timings) # Emit info about how long it took to perform the lowering @@ -261,15 +265,12 @@ def _build(cls, expressions, **kwargs): op._state = cls._initialize_state(**kwargs) # Produced by the various compilation passes - op._reads = filter_sorted(flatten(e.reads for e in irs.expressions)) - op._writes = filter_sorted(flatten(e.writes for e in irs.expressions)) + op._reads = filter_sorted(flatten(e.read_functions for e in irs.expressions)) + op._writes = filter_sorted(flatten(e.write_functions for e in irs.expressions)) op._dimensions = set().union(*[e.dimensions for e in irs.expressions]) op._dtype, op._dspace = irs.clusters.meta op._profiler = profiler - # Clear build-scoped instance caches - CacheInstances.clear_caches() - return op def __init__(self, *args, **kwargs): diff --git a/devito/symbolics/search.py b/devito/symbolics/search.py index 9c30948064..55064cbc23 100644 --- a/devito/symbolics/search.py +++ 
b/devito/symbolics/search.py @@ -8,13 +8,14 @@ from devito.symbolics.queries import ( q_derivative, q_dimension, q_function, q_indexed, q_leaf, q_symbol, q_terminal ) -from devito.tools import as_tuple +from devito.tools import as_tuple, memoized_func __all__ = [ 'retrieve_derivatives', 'retrieve_dimensions', 'retrieve_function_carriers', 'retrieve_functions', + 'retrieve_accesses', 'retrieve_indexed', 'retrieve_symbols', 'retrieve_terminals', @@ -140,10 +141,19 @@ def retrieve_indexed(exprs, mode='all', deep=False): def retrieve_functions(exprs, mode='all', deep=False): """Shorthand to retrieve the DiscreteFunctions in `exprs`.""" - indexeds = search(exprs, q_indexed, mode, 'dfs', deep) + query = lambda i: q_function(i) or q_indexed(i) + found = search(exprs, query, 'all', 'dfs', deep) + + functions = modes[mode]() + indexed_functions = set() + + for i in found: + if q_function(i): + functions.add(i) if mode == 'unique' else functions.append(i) + else: + indexed_functions.add(i.function) - functions = search(exprs, q_function, mode, 'dfs', deep) - functions.update({i.function for i in indexeds}) + functions.update(indexed_functions) return functions @@ -177,6 +187,26 @@ def retrieve_terminals(exprs, mode='all', deep=False): return search(exprs, q_terminal, mode, 'dfs', deep) +@memoized_func(scope='build') +def retrieve_accesses(exprs, deep=False): + """ + Like retrieve_terminals, but ensure that if a ComponentAccess is found, + the ComponentAccess itself is returned, while the wrapped Indexed is discarded. 
+ """ + from devito.symbolics.manipulation import uxreplace + from devito.types import ComponentAccess, Symbol + + compaccs = search(exprs, ComponentAccess) + if not compaccs: + return frozenset(retrieve_terminals(exprs, mode='unique', deep=deep)) + + subs = {i: Symbol(f'dummy{n}') for n, i in enumerate(compaccs)} + exprs1 = uxreplace(exprs, subs) + + return frozenset(compaccs | retrieve_terminals(exprs1, mode='unique', deep=deep) - + set(subs.values())) + + def retrieve_dimensions(exprs, mode='all', deep=False): """Shorthand to retrieve the dimensions in ``exprs``.""" return search(exprs, q_dimension, mode, 'dfs', deep) diff --git a/devito/tools/memoization.py b/devito/tools/memoization.py index c10f5ea092..bb0de4de30 100644 --- a/devito/tools/memoization.py +++ b/devito/tools/memoization.py @@ -19,9 +19,22 @@ class memoized_func: https://wiki.python.org/moin/PythonDecoratorLibrary#Memoize """ - def __init__(self, func): + # Long-lived caches for process-global helpers, such as arch discovery. + _scope_persistent = 'persistent' + # Build-scoped caches that may retain compiler inputs during Operator construction. 
+ _scope_build = 'build' + _scoped_caches = {} + + def __new__(cls, func=None, *, scope=None): + if func is None: + return lambda f: cls(f, scope=scope) + return super().__new__(cls) + + def __init__(self, func, *, scope=None): self.func = func + self.scope = scope or self._scope_persistent self.cache = {} + self._scoped_caches.setdefault(self.scope, set()).add(self) def __call__(self, *args, **kw): if not isinstance(args, Hashable): @@ -44,6 +57,18 @@ def __get__(self, obj, objtype): """Support instance methods.""" return partial(self.__call__, obj) + def clear(self): + self.cache.clear() + + @classmethod + def clear_scoped_caches(cls, scope): + for cache in cls._scoped_caches.get(scope, ()): + cache.clear() + + @classmethod + def clear_build_caches(cls): + cls.clear_scoped_caches(cls._scope_build) + class memoized_meth: """ diff --git a/devito/types/caching.py b/devito/types/caching.py index 742c0b3d33..9fbbcaa638 100644 --- a/devito/types/caching.py +++ b/devito/types/caching.py @@ -4,7 +4,7 @@ import sympy from sympy.core import cache -from devito.tools import safe_dict_copy +from devito.tools import memoized_func, safe_dict_copy __all__ = ['CacheManager', 'Cached', 'Uncached', '_SymbolCache'] @@ -175,6 +175,10 @@ def clear(cls, force=True): # SymPy 1.14 and later pass + # Drop compiler-scoped Python memoization that may still hold strong + # references to symbolic objects pending collection. 
+ memoized_func.clear_build_caches() + # Take a copy of the dictionary so we can safely iterate over it # even if another thread is making changes cache_copied = safe_dict_copy(_SymbolCache) diff --git a/tests/test_ir.py b/tests/test_ir.py index 16440ec54a..bb79e6b212 100644 --- a/tests/test_ir.py +++ b/tests/test_ir.py @@ -788,6 +788,18 @@ def test_indirect_access(self): v = scope.d_flow.pop() assert v.function is s1 + def test_ireq_function_views_indirect_indices(self): + grid = Grid(shape=(4,)) + x, = grid.dimensions + + u = Function(name='u', grid=grid) + f = Function(name='f', grid=grid) + a = Function(name='a', grid=grid) + + expr = LoweredEq(Eq(u, f[a[x]])) + + assert set(expr.read_functions) == {f, a} + def test_array_shared(self): grid = Grid(shape=(4, 4)) x, y = grid.dimensions diff --git a/tests/test_symbolics.py b/tests/test_symbolics.py index aab8502a20..5e6a4150e4 100644 --- a/tests/test_symbolics.py +++ b/tests/test_symbolics.py @@ -848,6 +848,18 @@ def test_is_on_grid(): assert all(uu._grid_map == {} for uu in retrieve_functions(u.subs({x: x0}).evaluate)) +def test_retrieve_functions_mixed_carriers(): + grid = Grid((10,)) + x = grid.dimensions[0] + + f = Function(name='f', grid=grid) + g = Function(name='g', grid=grid) + + expr = f + FIndexed(g.base, x) + + assert retrieve_functions(expr, mode='unique') == {f, g} + + @pytest.mark.parametrize('expr,expected', [ ('f[x+2]*g[x+4] + f[x+3]*g[x+5] + f[x+4] + f[x+1]', ['f[x+2]', 'g[x+4]', 'f[x+3]', 'g[x+5]', 'f[x+1]', 'f[x+4]']), diff --git a/tests/test_tools.py b/tests/test_tools.py index 0b06883e78..89856d920b 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -212,7 +212,6 @@ def __init__(self, value: int): cache_size = Object._instance_cache.cache_info()[-1] assert cache_size == 0 - def test_switchenv(): # Save previous environment previous_environ = dict(os.environ) From c4864ca5bb2a264338eb2d24990381c8043c7c3b Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Thu, 2 Apr 2026 11:46:58 
+0100 Subject: [PATCH 02/45] compiler: Augment caching and tweak memoization heuristics --- devito/ir/clusters/cluster.py | 27 ++++++++++--- devito/ir/clusters/visitors.py | 4 ++ devito/ir/support/properties.py | 5 +++ devito/ir/support/space.py | 70 ++++++++++++++++++++++++--------- devito/tools/memoization.py | 5 ++- tests/test_ir.py | 57 ++++++++++++++++++++++++++- tests/test_tools.py | 25 ++++++++++++ 7 files changed, 166 insertions(+), 27 deletions(-) diff --git a/devito/ir/clusters/cluster.py b/devito/ir/clusters/cluster.py index cbced4e606..546a8901f5 100644 --- a/devito/ir/clusters/cluster.py +++ b/devito/ir/clusters/cluster.py @@ -121,12 +121,27 @@ def rebuild(self, *args, **kwargs): raise ValueError("`exprs` provided both as arg and kwarg") kwargs['exprs'] = args[0] - return self.__class__(exprs=kwargs.get('exprs', self.exprs), - ispace=kwargs.get('ispace', self.ispace), - guards=kwargs.get('guards', self.guards), - properties=kwargs.get('properties', self.properties), - syncs=kwargs.get('syncs', self.syncs), - halo_scheme=kwargs.get('halo_scheme', self.halo_scheme)) + exprs = kwargs.get('exprs', self.exprs) + ispace = kwargs.get('ispace', self.ispace) + guards = kwargs.get('guards', self.guards) + properties = kwargs.get('properties', self.properties) + syncs = kwargs.get('syncs', self.syncs) + halo_scheme = kwargs.get('halo_scheme', self.halo_scheme) + + if exprs is self.exprs and \ + ispace is self.ispace and \ + guards is self.guards and \ + properties is self.properties and \ + syncs is self.syncs and \ + halo_scheme is self.halo_scheme: + return self + + return self.__class__(exprs=exprs, + ispace=ispace, + guards=guards, + properties=properties, + syncs=syncs, + halo_scheme=halo_scheme) @property def exprs(self): diff --git a/devito/ir/clusters/visitors.py b/devito/ir/clusters/visitors.py index 11bcad5365..a5b6587a9f 100644 --- a/devito/ir/clusters/visitors.py +++ b/devito/ir/clusters/visitors.py @@ -113,6 +113,10 @@ def _process_fatd(self, 
clusters, level, prefix=None, **kwargs): class Prefix(IterationSpace): + @classmethod + def _preprocess_args(cls, ispace, guards, properties, syncs): + return (ispace, guards, properties, syncs), {} + def __init__(self, ispace, guards, properties, syncs): super().__init__(ispace.intervals, ispace.sub_iterators, ispace.directions) diff --git a/devito/ir/support/properties.py b/devito/ir/support/properties.py index 9e787a8b9e..8dc759cc73 100644 --- a/devito/ir/support/properties.py +++ b/devito/ir/support/properties.py @@ -199,6 +199,11 @@ class Properties(frozendict): A mapper {Dimension -> {properties}}. """ + def __init__(self, *args, **kwargs): + mapper = dict(*args, **kwargs) + mapper = {d: frozenset(as_tuple(v)) for d, v in mapper.items()} + super().__init__(mapper) + @property def dimensions(self): return tuple(self) diff --git a/devito/ir/support/space.py b/devito/ir/support/space.py index 1c760128b5..a1ef7e50f8 100644 --- a/devito/ir/support/space.py +++ b/devito/ir/support/space.py @@ -8,8 +8,8 @@ from devito.ir.support.utils import maximum, minimum from devito.ir.support.vector import Vector, vmax, vmin from devito.tools import ( - Ordering, Stamp, as_list, as_set, as_tuple, filter_ordered, flatten, frozendict, - is_integer, toposort + CacheInstances, Ordering, Stamp, as_list, as_set, as_tuple, filter_ordered, + flatten, frozendict, is_integer, toposort ) from devito.types import Dimension, ModuloDimension @@ -88,7 +88,7 @@ def negate(self): translate = negate -class NullInterval(AbstractInterval): +class NullInterval(AbstractInterval, CacheInstances): """ A degenerate iterated closed interval on Z. 
@@ -96,6 +96,10 @@ class NullInterval(AbstractInterval): is_Null = True + @classmethod + def _preprocess_args(cls, dim, stamp=S0): + return (dim, stamp), {} + def __repr__(self): return f"{self.dim}[Null]{self.stamp}" @@ -120,7 +124,7 @@ def switch(self, d): return NullInterval(d, self.stamp) -class Interval(AbstractInterval): +class Interval(AbstractInterval, CacheInstances): """ Interval(dim, lower, upper) @@ -134,6 +138,18 @@ class Interval(AbstractInterval): is_Defined = True + @classmethod + def _preprocess_args(cls, dim, lower=0, upper=0, stamp=S0): + try: + lower = int(lower) + except TypeError: + assert isinstance(lower, Expr) + try: + upper = int(upper) + except TypeError: + assert isinstance(upper, Expr) + return (dim, lower, upper, stamp), {} + def __init__(self, dim, lower=0, upper=0, stamp=S0): super().__init__(dim, stamp) @@ -304,12 +320,18 @@ def expand(self): ) -class IntervalGroup(Ordering): +class IntervalGroup(Ordering, CacheInstances): """ A sequence of Intervals equipped with set-like operations. 
""" + @classmethod + def _preprocess_args(cls, items=None, relations=None, mode='total'): + items = as_tuple(items) + relations = tuple(tuple(i) for i in as_tuple(relations)) + return (items,), {'relations': relations, 'mode': mode} + @classmethod def reorder(cls, items, relations): if not all(isinstance(i, AbstractInterval) for i in items): @@ -335,13 +357,13 @@ def simplify_relations(cls, relations, items, mode): return super().simplify_relations(relations, items, mode) def __eq__(self, o): - return len(self) == len(o) and all(i == j for i, j in zip(self, o, strict=True)) + return isinstance(o, IntervalGroup) and super().__eq__(o) def __contains__(self, d): return any(i.dim is d for i in self) def __hash__(self): - return hash(tuple(self)) + return hash((tuple(self), self.relations, self.mode)) def __repr__(self): return "IntervalGroup[{}]".format(', '.join([repr(i) for i in self])) @@ -618,6 +640,11 @@ class IterationInterval(Interval): An Interval associated with metadata. """ + @classmethod + def _preprocess_args(cls, interval, sub_iterators=(), direction=Forward): + sub_iterators = tuple(filter_ordered(as_tuple(sub_iterators))) + return (interval, sub_iterators, direction), {} + def __init__(self, interval, sub_iterators=(), direction=Forward): super().__init__(interval.dim, *interval.offsets, stamp=interval.stamp) self.sub_iterators = sub_iterators @@ -768,8 +795,7 @@ def reset(self): return DataSpace(intervals, parts) - -class IterationSpace(Space): +class IterationSpace(Space, CacheInstances): """ Represent an iteration space as a Space with additional metadata and operations. @@ -785,23 +811,29 @@ class IterationSpace(Space): A mapper from Dimensions in ``intervals`` to IterationDirections. 
""" - def __init__(self, intervals, sub_iterators=None, directions=None): - super().__init__(intervals) + @classmethod + def _preprocess_args(cls, intervals, sub_iterators=None, directions=None): + if not isinstance(intervals, IntervalGroup): + intervals = IntervalGroup(as_tuple(intervals)) - # Normalize sub-iterators sub_iterators = sub_iterators or {} sub_iterators = {d: tuple(filter_ordered(as_tuple(v))) - for d, v in sub_iterators.items() if d in self.intervals} - sub_iterators.update({i.dim: () for i in self.intervals + for d, v in sub_iterators.items() if d in intervals} + sub_iterators.update({i.dim: () for i in intervals if i.dim not in sub_iterators}) - self._sub_iterators = frozendict(sub_iterators) - # Normalize directions directions = directions or {} - directions = {d: v for d, v in directions.items() if d in self.intervals} - directions.update({i.dim: Any for i in self.intervals + directions = {d: v for d, v in directions.items() if d in intervals} + directions.update({i.dim: Any for i in intervals if i.dim not in directions}) - self._directions = frozendict(directions) + + return (intervals, frozendict(sub_iterators), frozendict(directions)), {} + + def __init__(self, intervals, sub_iterators=None, directions=None): + super().__init__(intervals) + + self._sub_iterators = sub_iterators + self._directions = directions def __repr__(self): ret = ', '.join([f"{repr(i)}{repr(self.directions[i.dim])}" diff --git a/devito/tools/memoization.py b/devito/tools/memoization.py index bb0de4de30..e92101763a 100644 --- a/devito/tools/memoization.py +++ b/devito/tools/memoization.py @@ -179,6 +179,9 @@ def __init__(cls: type[InstanceType], *args) -> None: # type: ignore def __call__(cls: type[InstanceType], # type: ignore *args, **kwargs) -> InstanceType: + if cls._instance_cache_size == 0: + return super().__call__(*args, **kwargs) + args, kwargs = cls._preprocess_args(*args, **kwargs) return cls._instance_cache(*args, **kwargs) @@ -198,7 +201,7 @@ class 
CacheInstances(metaclass=CacheInstancesMeta): """ _instance_cache: Callable | None = None - _instance_cache_size: int = 128 + _instance_cache_size: int = 8192 @classmethod def _preprocess_args(cls, *args, **kwargs): diff --git a/tests/test_ir.py b/tests/test_ir.py index bb79e6b212..70e2de4af4 100644 --- a/tests/test_ir.py +++ b/tests/test_ir.py @@ -17,7 +17,8 @@ ) from devito.ir.support.guards import GuardOverflow from devito.ir.support.space import ( - Backward, Forward, Interval, IntervalGroup, IterationSpace, NullInterval + Backward, Forward, Interval, IntervalGroup, IterationInterval, IterationSpace, + NullInterval ) from devito.symbolics import DefFunction, FieldFromPointer from devito.tools import prod @@ -359,6 +360,60 @@ def x(self, grid): def y(self, grid): return grid.dimensions[1] + def test_null_interval_cache_identity(self, x): + i0 = NullInterval(x) + i1 = NullInterval(x) + + assert i0 is i1 + + def test_interval_cache_identity(self, x): + i0 = Interval(x, -2, 2) + i1 = Interval(x, -2, 2) + + assert i0 is i1 + + def test_iteration_interval_cache_identity(self, x): + xi = SubDimension.middle('xi', x, 1, 1) + + i0 = IterationInterval(Interval(x), (xi,), Forward) + i1 = IterationInterval(Interval(x), (xi,), Forward) + + assert i0 is i1 + + def test_iteration_interval_cache_distinguishes_sub_iterators(self, x): + xi = SubDimension.middle('xi', x, 1, 1) + + i0 = IterationInterval(Interval(x), (xi,), Forward) + i1 = IterationInterval(Interval(x), (), Forward) + + assert i0 is not i1 + + def test_interval_group_cache_identity(self, x, y): + ig0 = IntervalGroup([Interval(x, -2, 2), Interval(y, -1, 1)], + relations=((x, y),), mode='partial') + ig1 = IntervalGroup((Interval(x, -2, 2), Interval(y, -1, 1)), + relations=((x, y),), mode='partial') + + assert ig0 is ig1 + + def test_iteration_space_cache_identity(self, x): + xi = SubDimension.middle('xi', x, 1, 1) + + ispace0 = IterationSpace([Interval(x)], {x: (xi,)}, {x: Forward}) + ispace1 = 
IterationSpace([Interval(x)], {x: (xi,)}, {x: Forward}) + + assert ispace0 is ispace1 + assert isinstance(ispace0[x], IterationInterval) + assert ispace0[x] is ispace1[x] + + def test_iteration_space_cache_distinguishes_sub_iterators(self, x): + xi = SubDimension.middle('xi', x, 1, 1) + + ispace0 = IterationSpace([Interval(x)], {x: (xi,)}, {x: Forward}) + ispace1 = IterationSpace([Interval(x)], directions={x: Forward}) + + assert ispace0 is not ispace1 + def test_intervals_intersection(self, x, y): nullx = NullInterval(x) diff --git a/tests/test_tools.py b/tests/test_tools.py index 89856d920b..bef55a82c6 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -212,6 +212,31 @@ def __init__(self, value: int): cache_size = Object._instance_cache.cache_info()[-1] assert cache_size == 0 + def test_uncached_subclass_bypasses_parent_preprocess(self): + """ + Tests that an uncached subclass does not inherit its parent's + preprocessing contract. + """ + class Parent(CacheInstances): + @classmethod + def _preprocess_args(cls, value): + return (value + 1,), {} + + def __init__(self, value: int): + self.value = value + + class Child(Parent): + _instance_cache_size = 0 + + def __init__(self, left: int, right: int): + self.value = (left, right) + + obj0 = Child(1, 2) + obj1 = Child(1, 2) + + assert obj0.value == (1, 2) + assert obj0 is not obj1 + def test_switchenv(): # Save previous environment previous_environ = dict(os.environ) From 77efae6c498553f41c214a667b5eaef4fcb86ab4 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Fri, 3 Apr 2026 13:05:45 +0100 Subject: [PATCH 03/45] compiler: cache TimedAccess instances --- devito/ir/support/basic.py | 8 +++++++- tests/test_ir.py | 8 +++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/devito/ir/support/basic.py b/devito/ir/support/basic.py index ba26370dda..0bb051e868 100644 --- a/devito/ir/support/basic.py +++ b/devito/ir/support/basic.py @@ -200,7 +200,7 @@ def is_scalar(self): return self.rank == 0 
-class TimedAccess(IterationInstance, AccessMode): +class TimedAccess(IterationInstance, AccessMode, CacheInstances): """ A TimedAccess ties together an IterationInstance and an AccessMode. @@ -218,6 +218,12 @@ class TimedAccess(IterationInstance, AccessMode): on the values of the index functions and the access mode (read, write). """ + @classmethod + def _preprocess_args(cls, access, mode, timestamp, ispace=None): + if ispace is None: + ispace = null_ispace + return (access, mode, timestamp, ispace), {} + def __new__(cls, access, mode, timestamp, ispace=None): obj = super().__new__(cls, access) AccessMode.__init__(obj, mode=mode) diff --git a/tests/test_ir.py b/tests/test_ir.py index 70e2de4af4..5ee01273ff 100644 --- a/tests/test_ir.py +++ b/tests/test_ir.py @@ -18,7 +18,7 @@ from devito.ir.support.guards import GuardOverflow from devito.ir.support.space import ( Backward, Forward, Interval, IntervalGroup, IterationInterval, IterationSpace, - NullInterval + NullInterval, null_ispace ) from devito.symbolics import DefFunction, FieldFromPointer from devito.tools import prod @@ -141,6 +141,12 @@ def test_vector_cmp(self, v_num, v_literal): assert v2 <= vs3 assert vs3 > v2 + def test_timedaccess_cached(self, fc, x, y): + ta0 = TimedAccess(fc[x, y], 'R', 0) + ta1 = TimedAccess(fc[x, y], 'R', 0, null_ispace) + + assert ta0 is ta1 + def test_iteration_instance_arithmetic(self, x, y, ii_num, ii_literal): """ Test arithmetic operations involving objects of type IterationInstance. 
From d41e2b3541d1d04352c4216c85c8b237af6e98d5 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Tue, 7 Apr 2026 12:24:18 +0100 Subject: [PATCH 04/45] compiler: Add heuristics to improve fusion lowering turnaround --- devito/ir/support/basic.py | 116 ++++++++- devito/passes/clusters/__init__.py | 1 + devito/passes/clusters/derivatives.py | 2 +- devito/passes/clusters/fusion.py | 340 ++++++++++++++++++++++++++ devito/passes/clusters/misc.py | 310 +---------------------- 5 files changed, 453 insertions(+), 316 deletions(-) create mode 100644 devito/passes/clusters/fusion.py diff --git a/devito/ir/support/basic.py b/devito/ir/support/basic.py index 0bb051e868..b127b308e5 100644 --- a/devito/ir/support/basic.py +++ b/devito/ir/support/basic.py @@ -326,6 +326,13 @@ def lex_le(self, other): def lex_lt(self, other): return self.timestamp < other.timestamp + @memoized_meth + def shifted(self, offset): + if offset == 0: + return self + + return TimedAccess(self.access, self.mode, self.timestamp + offset, self.ispace) + @memoized_meth def distance(self, other, logical=False): """ @@ -860,18 +867,58 @@ class Scope(CacheInstances): # Describes a rule for dependencies Rule = Callable[[TimedAccess, TimedAccess], bool] + @classmethod + def from_scopes(cls, scope0, scope1, rules=None): + """ + Build a Scope out of two existing Scopes by reusing their cached accesses. 
+ """ + offset = len(scope0.exprs) + + targets = ( + scope0.write_targets & scope1.functions + ) | ( + scope1.write_targets & scope0.functions + ) + + reads = {} + writes = {} + + for f in targets: + shifted = tuple(i.shifted(offset) for i in scope1.getreads(f)) + accesses = scope0.getreads(f) + if shifted: + accesses = accesses + shifted if accesses else shifted + if accesses: + reads[f] = accesses + + shifted = tuple(i.shifted(offset) for i in scope1.getwrites(f)) + accesses = scope0.getwrites(f) + if shifted: + accesses = accesses + shifted if accesses else shifted + if accesses: + writes[f] = accesses + + return cls((), rules=rules, reads=reads.items(), writes=writes.items()) + @classmethod def _preprocess_args(cls, exprs: Expr | Iterable[Expr], **kwargs) -> tuple[tuple, dict]: + for i in ('reads', 'writes'): + with suppress(KeyError): + kwargs[i] = tuple(kwargs[i]) + return (as_tuple(exprs),), kwargs def __init__(self, exprs: tuple[Expr], - rules: Rule | tuple[Rule] | None = None) -> None: + rules: Rule | tuple[Rule] | None = None, + reads=None, writes=None) -> None: """ A Scope enables data dependence analysis on a totally ordered sequence of expressions. """ self.exprs = exprs + self._reads = dict(reads) if reads is not None else None + self._writes = dict(writes) if writes is not None else None # A set of rules to drive the collection of dependencies self.rules: tuple[Scope.Rule] = as_tuple(rules) # type: ignore[assignment] @@ -910,6 +957,9 @@ def writes(self): """ Create a mapper from functions to write accesses. """ + if self._writes is not None: + return self._writes + return as_mapper(self.writes_gen(), key=lambda i: i.function) @memoized_generator @@ -1004,23 +1054,29 @@ def reads_smart_gen(self, f): the iteration symbols. 
""" if isinstance(f, (Function, Temp, TempArray, TBArray)): - for i in chain(self.reads_explicit_gen(), self.reads_synchro_gen()): - if f is i.function: - for j in extrema(i.access): - yield TimedAccess(j, i.mode, i.timestamp, i.ispace) + for i in self.getreads(f): + for j in extrema(i.access): + yield TimedAccess(j, i.mode, i.timestamp, i.ispace) else: - for i in self.reads_gen(): - if f is i.function: - yield i + for i in self.getreads(f): + yield i @cached_property def reads(self): """ Create a mapper from functions to read accesses. """ + if self._reads is not None: + return self._reads + return as_mapper(self.reads_gen(), key=lambda i: i.function) + @cached_property + def read_targets(self): + """The Functions read within the Scope.""" + return frozenset(self.reads) + @cached_property def read_only(self): """ @@ -1028,6 +1084,16 @@ def read_only(self): """ return set(self.reads) - set(self.writes) + @cached_property + def write_targets(self): + """The Functions written within the Scope.""" + return frozenset(self.writes) + + @cached_property + def has_barrier(self): + """True if the Scope contains a fence-like control-flow object.""" + return any(isinstance(e.rhs, (Fence, CriticalRegion)) for e in self.exprs) + @cached_property def initialized(self): return frozenset(e.lhs.function for e in self.exprs @@ -1083,6 +1149,23 @@ def indexeds(self): def functions(self): return set(self.reads) | set(self.writes) + @memoized_meth + def may_interact(self, other, has_barrier=False): + """ + True if the Scope may induce cross-scope ordering constraints. + + This is a cheap pre-check used to avoid full dependence analysis when + two scopes do not touch any common Function through a write and no + fence-like object lies between them. 
+ """ + if has_barrier or self.has_barrier or other.has_barrier: + return True + + if self.write_targets & other.functions: + return True + + return bool(other.write_targets & self.functions) + @memoized_meth def a_query(self, timestamps=None, modes=None): timestamps = as_tuple(timestamps) @@ -1094,8 +1177,9 @@ def a_query(self, timestamps=None, modes=None): def d_flow_gen(self): """Generate the flow (or "read-after-write") dependences.""" for k, v in self.writes.items(): + reads = tuple(self.reads_smart_gen(k)) for w in v: - for r in self.reads_smart_gen(k): + for r in reads: if any(not rule(w, r) for rule in self.rules): continue @@ -1125,8 +1209,9 @@ def d_flow(self): def d_anti_gen(self, depcls=Dependence): """Generate the anti (or "write-after-read") dependences.""" for k, v in self.writes.items(): + reads = tuple(self.reads_smart_gen(k)) for w in v: - for r in self.reads_smart_gen(k): + for r in reads: if any(not rule(r, w) for rule in self.rules): continue @@ -1377,6 +1462,17 @@ def vinf(entries): return Vector(*(entries + [S.Infinity])) +def _cause_from_distance(findices, distance): + for i, j in zip(findices, distance, strict=False): + try: + if j > 0: + return i._defines + except TypeError: + return i._defines + + return frozenset() + + def disjoint_test(e0, e1, d, it): """ A rudimentary test to check if two accesses `e0` and `e1` along `d` within diff --git a/devito/passes/clusters/__init__.py b/devito/passes/clusters/__init__.py index c41a628e06..e27d2b755d 100644 --- a/devito/passes/clusters/__init__.py +++ b/devito/passes/clusters/__init__.py @@ -3,6 +3,7 @@ from .cse import * # noqa from .aliases import * # noqa from .factorization import * # noqa +from .fusion import * # noqa from .blocking import * # noqa from .asynchrony import * # noqa from .implicit import * # noqa diff --git a/devito/passes/clusters/derivatives.py b/devito/passes/clusters/derivatives.py index 194d26523d..fcbdbd5b8a 100644 --- a/devito/passes/clusters/derivatives.py +++ 
b/devito/passes/clusters/derivatives.py @@ -5,7 +5,7 @@ from devito.finite_differences import IndexDerivative, Weights from devito.ir import Backward, Forward, Interval, IterationSpace, Queue -from devito.passes.clusters.misc import fuse +from devito.passes.clusters.fusion import fuse from devito.symbolics import BasicWrapperMixin, reuse_if_untouched, uxreplace from devito.tools import infer_dtype, timed_pass from devito.types import Eq, Inc, Indexed, Symbol diff --git a/devito/passes/clusters/fusion.py b/devito/passes/clusters/fusion.py new file mode 100644 index 0000000000..60b5977d6d --- /dev/null +++ b/devito/passes/clusters/fusion.py @@ -0,0 +1,340 @@ +from collections import Counter, defaultdict +from itertools import groupby + +from devito.finite_differences import IndexDerivative +from devito.ir.clusters import Cluster, ClusterGroup, Queue +from devito.ir.support import ( + InitArray, PrefetchUpdate, ReleaseLock, Scope, SyncArray, WaitLock, WithLock +) +from devito.symbolics import search +from devito.tools import ( + DAG, as_tuple, flatten, frozendict, memoized_func, timed_pass +) + +__all__ = ['fuse'] + + +def _is_prefix_carried(dependence, prefix): + return bool(dependence.cause & prefix) + + +@memoized_func(scope='build') +def _fusion_hazards(scope, prefix): + anti = False + for i in scope.d_anti_gen(): + if _is_prefix_carried(i, prefix): + return True, True + anti = True + + if anti: + return False, True + + for i in scope.d_flow_gen(): + if not _is_prefix_carried(i, prefix): + return False, True + + for _ in scope.d_output_gen(): + return False, True + + return False, False + + +class Fusion(Queue): + + """ + Fuse Clusters with compatible IterationSpace. 
+ """ + + _q_guards_in_key = True + _q_syncs_in_key = True + + def __init__(self, toposort, options=None): + options = options or {} + + self.toposort = toposort + self.fusetasks = options.get('fuse-tasks', False) + + super().__init__() + + def process(self, clusters): + cgroups = [ClusterGroup(c, c.ispace) for c in clusters] + cgroups = self._process_fdta(cgroups, 1) + clusters = ClusterGroup.concatenate(*cgroups) + return clusters + + def callback(self, cgroups, prefix): + # Toposort to maximize fusion + if self.toposort: + clusters = self._toposort(cgroups, prefix) + if self.toposort == 'nofuse': + return [clusters] + else: + clusters = ClusterGroup(cgroups) + + # Fusion + processed = [] + for _, group in groupby(clusters, key=self._key): + g = list(group) + + for maybe_fusible in self._apply_heuristics(g): + try: + # Perform fusion + processed.append(Cluster.from_clusters(*maybe_fusible)) + except ValueError: + # We end up here if, for example, some Clusters have same + # iteration Dimensions but different (partial) orderings + processed.extend(maybe_fusible) + + # Maximize effectiveness of topo-sorting at next stage by only + # grouping together Clusters characterized by data dependencies + if self.toposort and prefix: + dag = self._build_dag(processed, prefix) + mapper = dag.connected_components(enumerated=True) + groups = groupby(processed, key=mapper.get) + return [ClusterGroup(tuple(g), prefix) for _, g in groups] + else: + return [ClusterGroup(processed, prefix)] + + class Key(tuple): + + """ + A fusion Key for a Cluster (ClusterGroup) is a hashable tuple such that + two Clusters (ClusterGroups) are topo-fusible if and only if their Key is + identical. + + A Key contains elements that can logically be split into two groups -- the + `strict` and the `weak` components of the Key. 
Two Clusters (ClusterGroups) + having same `strict` but different `weak` parts are, by definition, not + fusible; however, since at least their `strict` parts match, they can at + least be topologically reordered. + """ + + def __new__(cls, itintervals, guards, syncs, weak): + strict = [itintervals, guards, syncs] + obj = super().__new__(cls, strict + weak) + + obj.itintervals = itintervals + obj.guards = guards + obj.syncs = syncs + + obj.strict = tuple(strict) + obj.weak = tuple(weak) + + return obj + + def _key(self, c): + itintervals = frozenset(c.ispace.itintervals) + guards = c.guards if any(c.guards) else None + + # We allow fusing Clusters/ClusterGroups even in presence of WaitLocks and + # WithLocks, but not with any other SyncOps + mapper = defaultdict(set) + for d, v in c.syncs.items(): + for s in v: + if isinstance(s, PrefetchUpdate): + continue + elif isinstance(s, WaitLock) and not self.fusetasks: + # NOTE: A mix of Clusters w/ and w/o WaitLocks can safely + # be fused, as in the worst case scenario the WaitLocks + # get "hoisted" above the first Cluster in the sequence + continue + elif isinstance(s, (InitArray, SyncArray, WaitLock, ReleaseLock)): + mapper[d].add(type(s)) + elif isinstance(s, WithLock) and self.fusetasks: + # NOTE: Different WithLocks aren't fused unless the user + # explicitly asks for it + mapper[d].add(type(s)) + else: + mapper[d].add(s) + if d in mapper: + mapper[d] = frozenset(mapper[d]) + syncs = frozendict(mapper) + + # Clusters representing HaloTouches should get merged, if possible + weak = [c.is_halo_touch] + + # If there are writes to thread-shared object, make it part of the key. + # This will promote fusion of non-adjacent Clusters writing to (some + # form of) shared memory, which in turn will minimize the number of + # necessary barriers. 
Same story for reads from thread-shared objects + weak.extend([ + any(f._mem_shared for f in c.scope.writes), + any(f._mem_shared for f in c.scope.reads) + ]) + weak.append(c.properties.is_core_init()) + + # Prefetchable Clusters should get merged, if possible + weak.append(c.is_glb_load_to_mem_shared) + + # Promoting adjacency of IndexDerivatives will maximize their reuse + weak.append(any(search(c.exprs, IndexDerivative))) + + # Promote adjacency of Clusters with same guard + weak.append(c.guards) + + key = self.Key(itintervals, guards, syncs, weak) + + return key + + def _apply_heuristics(self, clusters): + # We know at this point that `clusters` are potentially fusible since + # they have same `_key`, but should we actually fuse them? In most cases + # yes, but there are exceptions... + + # 1) Consider the following scenario with three Clusters: + # c0[no syncs] + # c1[WaitLock] + # c2[no syncs] + # Then we return two groups [[c0], [c1, c2]] rather than a single group + # [[c0, c1, c2]] because this way c0 can be computed without having to + # wait on a lock for a longer period + processed = [] + + group = [] + flag = False # True -> need to dump before creating a new group + + def dump(): + processed.append(tuple(group)) + group[:] = [] + + for c in clusters: + if any(isinstance(i, WaitLock) for i in flatten(c.syncs.values())): + if flag: + dump() + flag = False + else: + flag = True + group.append(c) + dump() + + # 2) Don't group HaloTouch's + groups, processed = processed, [] + for group in groups: + for flag, minigroup in groupby(group, key=lambda c: c.is_wild): + if flag: + processed.extend([(c,) for c in minigroup]) + else: + processed.append(tuple(minigroup)) + + return processed + + def _toposort(self, cgroups, prefix): + # Are there any ClusterGroups that could potentially be topologically + # reordered? 
If not, do not waste time + counter = Counter(self._key(cg).strict for cg in cgroups) + if len(counter.most_common()) == 1 or \ + not any(v > 1 for it, v in counter.most_common()): + return ClusterGroup(cgroups, prefix) + + dag = self._build_dag(cgroups, prefix) + + def choose_element(queue, scheduled): + if not scheduled: + return queue.pop() + + k = self._key(scheduled[-1]) + m = {i: self._key(i) for i in queue} + + # Process the `strict` part of the key + candidates = [i for i in queue if m[i].itintervals == k.itintervals] + + compatible = [i for i in candidates if m[i].guards == k.guards] + candidates = compatible or candidates + + compatible = [i for i in candidates if m[i].syncs == k.syncs] + candidates = compatible or candidates + + # Process the `weak` part of the key + for i in range(len(k.weak), -1, -1): + choosable = [e for e in candidates if m[e].weak[:i] == k.weak[:i]] + try: + # Ensure stability + e = min(choosable, key=lambda i: cgroups.index(i)) + except ValueError: + continue + queue.remove(e) + return e + + # Fallback + e = min(queue, key=lambda i: cgroups.index(i)) + queue.remove(e) + return e + + return ClusterGroup(dag.topological_sort(choose_element), prefix) + + def _build_dag(self, cgroups, prefix): + """ + A DAG representing the data dependences across the ClusterGroups within + a given scope. 
+ """ + prefix = frozenset(i.dim for i in as_tuple(prefix)) + + dag = DAG(nodes=cgroups) + scopes = [cg.scope for cg in cgroups] + + barrier_count = [0] + for scope in scopes: + barrier_count.append(barrier_count[-1] + int(scope.has_barrier)) + + for n, (cg0, scope0) in enumerate(zip(cgroups, scopes, strict=True)): + def is_cross(source, sink): + # True if a cross-ClusterGroup dependence, False otherwise + t0 = source.timestamp + t1 = sink.timestamp + v = len(cg0.exprs) + return t0 < v <= t1 or t1 < v <= t0 + + for n1, (cg1, scope1) in enumerate(zip(cgroups[n+1:], scopes[n+1:], + strict=True), start=n+1): + has_barrier = barrier_count[n1 + 1] > barrier_count[n] + if not scope0.may_interact(scope1, has_barrier): + continue + + # Reuse the cached per-ClusterGroup accesses instead of + # rescanning the symbolic expressions for each candidate pair. + scope = Scope.from_scopes(scope0, scope1, rules=is_cross) + anti_prefix, forbids_fusion = _fusion_hazards(scope, prefix) + + # Anti-dependences along `prefix` break the execution flow + # (intuitively, "the loop nests are to be kept separated") + # * All ClusterGroups between `cg0` and `cg1` must precede `cg1` + # * All ClusterGroups after `cg1` cannot precede `cg1` + if anti_prefix: + for cg2 in cgroups[n:n1]: + dag.add_edge(cg2, cg1) + for cg2 in cgroups[n1+1:]: + dag.add_edge(cg1, cg2) + break + elif has_barrier or forbids_fusion: + # Any anti- and iaw-dependences impose that `cg1` follows `cg0` + # and forbid any sort of fusion. Fences have the same effect + dag.add_edge(cg0, cg1) + + return dag + + +@timed_pass() +def fuse(clusters, toposort=False, options=None): + """ + Clusters fusion. + + If `toposort=True`, then the Clusters are reordered to maximize the likelihood + of fusion; the new ordering is computed such that all data dependencies are + honored. + + If `toposort='maximal'`, then `toposort` is performed, iteratively, multiple + times to actually maximize Clusters fusion. 
Hence, this is more aggressive than + `toposort=True`. + """ + if toposort != 'maximal': + return Fusion(toposort, options).process(clusters) + + nxt = clusters + while True: + nxt = fuse(clusters, toposort='nofuse', options=options) + if all(c0 is c1 for c0, c1 in zip(clusters, nxt, strict=True)): + break + clusters = nxt + clusters = fuse(clusters, toposort=False, options=options) + + return clusters diff --git a/devito/passes/clusters/misc.py b/devito/passes/clusters/misc.py index 494ebe7490..4530f6c28f 100644 --- a/devito/passes/clusters/misc.py +++ b/devito/passes/clusters/misc.py @@ -1,18 +1,15 @@ -from collections import Counter, defaultdict from itertools import groupby, product -from devito.finite_differences import IndexDerivative -from devito.ir.clusters import Cluster, ClusterGroup, Queue, cluster_pass +from devito.ir.clusters import Queue, cluster_pass from devito.ir.support import ( - SEPARABLE, SEQUENTIAL, InitArray, PrefetchUpdate, ReleaseLock, Scope, SyncArray, - WaitLock, WithLock + SEPARABLE, SEQUENTIAL, Scope ) from devito.passes.clusters.utils import in_critical_region -from devito.symbolics import pow_to_mul, search -from devito.tools import DAG, Stamp, as_tuple, flatten, frozendict, timed_pass +from devito.symbolics import pow_to_mul +from devito.tools import Stamp, flatten, frozendict, timed_pass from devito.types import Hyperplane -__all__ = ['Lift', 'fission', 'fuse', 'optimize_hyperplanes', 'optimize_pows'] +__all__ = ['Lift', 'fission', 'optimize_hyperplanes', 'optimize_pows'] class Lift(Queue): @@ -107,303 +104,6 @@ def callback(self, clusters, prefix): return lifted + processed -class Fusion(Queue): - - """ - Fuse Clusters with compatible IterationSpace. 
- """ - - _q_guards_in_key = True - _q_syncs_in_key = True - - def __init__(self, toposort, options=None): - options = options or {} - - self.toposort = toposort - self.fusetasks = options.get('fuse-tasks', False) - - super().__init__() - - def process(self, clusters): - cgroups = [ClusterGroup(c, c.ispace) for c in clusters] - cgroups = self._process_fdta(cgroups, 1) - clusters = ClusterGroup.concatenate(*cgroups) - return clusters - - def callback(self, cgroups, prefix): - # Toposort to maximize fusion - if self.toposort: - clusters = self._toposort(cgroups, prefix) - if self.toposort == 'nofuse': - return [clusters] - else: - clusters = ClusterGroup(cgroups) - - # Fusion - processed = [] - for _, group in groupby(clusters, key=self._key): - g = list(group) - - for maybe_fusible in self._apply_heuristics(g): - try: - # Perform fusion - processed.append(Cluster.from_clusters(*maybe_fusible)) - except ValueError: - # We end up here if, for example, some Clusters have same - # iteration Dimensions but different (partial) orderings - processed.extend(maybe_fusible) - - # Maximize effectiveness of topo-sorting at next stage by only - # grouping together Clusters characterized by data dependencies - if self.toposort and prefix: - dag = self._build_dag(processed, prefix) - mapper = dag.connected_components(enumerated=True) - groups = groupby(processed, key=mapper.get) - return [ClusterGroup(tuple(g), prefix) for _, g in groups] - else: - return [ClusterGroup(processed, prefix)] - - class Key(tuple): - - """ - A fusion Key for a Cluster (ClusterGroup) is a hashable tuple such that - two Clusters (ClusterGroups) are topo-fusible if and only if their Key is - identical. - - A Key contains elements that can logically be split into two groups -- the - `strict` and the `weak` components of the Key. 
Two Clusters (ClusterGroups) - having same `strict` but different `weak` parts are, by definition, not - fusible; however, since at least their `strict` parts match, they can at - least be topologically reordered. - """ - - def __new__(cls, itintervals, guards, syncs, weak): - strict = [itintervals, guards, syncs] - obj = super().__new__(cls, strict + weak) - - obj.itintervals = itintervals - obj.guards = guards - obj.syncs = syncs - - obj.strict = tuple(strict) - obj.weak = tuple(weak) - - return obj - - def _key(self, c): - itintervals = frozenset(c.ispace.itintervals) - guards = c.guards if any(c.guards) else None - - # We allow fusing Clusters/ClusterGroups even in presence of WaitLocks and - # WithLocks, but not with any other SyncOps - mapper = defaultdict(set) - for d, v in c.syncs.items(): - for s in v: - if isinstance(s, PrefetchUpdate): - continue - elif isinstance(s, WaitLock) and not self.fusetasks: - # NOTE: A mix of Clusters w/ and w/o WaitLocks can safely - # be fused, as in the worst case scenario the WaitLocks - # get "hoisted" above the first Cluster in the sequence - continue - elif isinstance(s, (InitArray, SyncArray, WaitLock, ReleaseLock)): - mapper[d].add(type(s)) - elif isinstance(s, WithLock) and self.fusetasks: - # NOTE: Different WithLocks aren't fused unless the user - # explicitly asks for it - mapper[d].add(type(s)) - else: - mapper[d].add(s) - if d in mapper: - mapper[d] = frozenset(mapper[d]) - syncs = frozendict(mapper) - - # Clusters representing HaloTouches should get merged, if possible - weak = [c.is_halo_touch] - - # If there are writes to thread-shared object, make it part of the key. - # This will promote fusion of non-adjacent Clusters writing to (some - # form of) shared memory, which in turn will minimize the number of - # necessary barriers. 
Same story for reads from thread-shared objects - weak.extend([ - any(f._mem_shared for f in c.scope.writes), - any(f._mem_shared for f in c.scope.reads) - ]) - weak.append(c.properties.is_core_init()) - - # Prefetchable Clusters should get merged, if possible - weak.append(c.is_glb_load_to_mem_shared) - - # Promoting adjacency of IndexDerivatives will maximize their reuse - weak.append(any(search(c.exprs, IndexDerivative))) - - # Promote adjacency of Clusters with same guard - weak.append(c.guards) - - key = self.Key(itintervals, guards, syncs, weak) - - return key - - def _apply_heuristics(self, clusters): - # We know at this point that `clusters` are potentially fusible since - # they have same `_key`, but should we actually fuse them? In most cases - # yes, but there are exceptions... - - # 1) Consider the following scenario with three Clusters: - # c0[no syncs] - # c1[WaitLock] - # c2[no syncs] - # Then we return two groups [[c0], [c1, c2]] rather than a single group - # [[c0, c1, c2]] because this way c0 can be computed without having to - # wait on a lock for a longer period - processed = [] - - group = [] - flag = False # True -> need to dump before creating a new group - - def dump(): - processed.append(tuple(group)) - group[:] = [] - - for c in clusters: - if any(isinstance(i, WaitLock) for i in flatten(c.syncs.values())): - if flag: - dump() - flag = False - else: - flag = True - group.append(c) - dump() - - # 2) Don't group HaloTouch's - - groups, processed = processed, [] - for group in groups: - for flag, minigroup in groupby(group, key=lambda c: c.is_wild): - if flag: - processed.extend([(c,) for c in minigroup]) - else: - processed.append(tuple(minigroup)) - - return processed - - def _toposort(self, cgroups, prefix): - # Are there any ClusterGroups that could potentially be topologically - # reordered? 
If not, do not waste time - counter = Counter(self._key(cg).strict for cg in cgroups) - if len(counter.most_common()) == 1 or \ - not any(v > 1 for it, v in counter.most_common()): - return ClusterGroup(cgroups, prefix) - - dag = self._build_dag(cgroups, prefix) - - def choose_element(queue, scheduled): - if not scheduled: - return queue.pop() - - k = self._key(scheduled[-1]) - m = {i: self._key(i) for i in queue} - - # Process the `strict` part of the key - candidates = [i for i in queue if m[i].itintervals == k.itintervals] - - compatible = [i for i in candidates if m[i].guards == k.guards] - candidates = compatible or candidates - - compatible = [i for i in candidates if m[i].syncs == k.syncs] - candidates = compatible or candidates - - # Process the `weak` part of the key - for i in range(len(k.weak), -1, -1): - choosable = [e for e in candidates if m[e].weak[:i] == k.weak[:i]] - try: - # Ensure stability - e = min(choosable, key=lambda i: cgroups.index(i)) - except ValueError: - continue - queue.remove(e) - return e - - # Fallback - e = min(queue, key=lambda i: cgroups.index(i)) - queue.remove(e) - return e - - return ClusterGroup(dag.topological_sort(choose_element), prefix) - - def _build_dag(self, cgroups, prefix): - """ - A DAG representing the data dependences across the ClusterGroups within - a given scope. 
- """ - prefix = {i.dim for i in as_tuple(prefix)} - - dag = DAG(nodes=cgroups) - for n, cg0 in enumerate(cgroups): - - def is_cross(source, sink): - # True if a cross-ClusterGroup dependence, False otherwise - t0 = source.timestamp - t1 = sink.timestamp - v = len(cg0.exprs) # noqa: B023 - return t0 < v <= t1 or t1 < v <= t0 - - for n1, cg1 in enumerate(cgroups[n+1:], start=n+1): - - # A Scope to compute all cross-ClusterGroup anti-dependences - scope = Scope(exprs=cg0.exprs + cg1.exprs, rules=is_cross) - - # Anti-dependences along `prefix` break the execution flow - # (intuitively, "the loop nests are to be kept separated") - # * All ClusterGroups between `cg0` and `cg1` must precede `cg1` - # * All ClusterGroups after `cg1` cannot precede `cg1` - if any(i.cause & prefix for i in scope.d_anti_gen()): - for cg2 in cgroups[n:cgroups.index(cg1)]: - dag.add_edge(cg2, cg1) - for cg2 in cgroups[cgroups.index(cg1)+1:]: - dag.add_edge(cg1, cg2) - break - - # Any anti- and iaw-dependences impose that `cg1` follows `cg0` - # and forbid any sort of fusion. Fences have the same effect - elif ( - any(scope.d_anti_gen()) or - any(i.is_iaw for i in scope.d_output_gen()) or - any(c.is_fence for c in flatten(cgroups[n:n1+1])) - ) or any(not (i.cause and i.cause & prefix) for i in scope.d_flow_gen()) \ - or any(scope.d_output_gen()): - dag.add_edge(cg0, cg1) - - return dag - - -@timed_pass() -def fuse(clusters, toposort=False, options=None): - """ - Clusters fusion. - - If `toposort=True`, then the Clusters are reordered to maximize the likelihood - of fusion; the new ordering is computed such that all data dependencies are - honored. - - If `toposort='maximal'`, then `toposort` is performed, iteratively, multiple - times to actually maximize Clusters fusion. Hence, this is more aggressive than - `toposort=True`. 
- """ - if toposort != 'maximal': - return Fusion(toposort, options).process(clusters) - - nxt = clusters - while True: - nxt = fuse(clusters, toposort='nofuse', options=options) - if all(c0 is c1 for c0, c1 in zip(clusters, nxt, strict=True)): - break - clusters = nxt - clusters = fuse(clusters, toposort=False, options=options) - - return clusters - - @cluster_pass(mode='all') def optimize_pows(cluster, *args): """ From 0541de40ef74f490a74e1f54ca7df235ea0727a7 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Wed, 8 Apr 2026 10:15:31 +0100 Subject: [PATCH 05/45] compiler: Improve Cluster fusion implementation --- devito/ir/support/basic.py | 90 ++++++++++++++------------------ devito/passes/clusters/fusion.py | 64 ++++++++++------------- 2 files changed, 68 insertions(+), 86 deletions(-) diff --git a/devito/ir/support/basic.py b/devito/ir/support/basic.py index b127b308e5..86c6c4e688 100644 --- a/devito/ir/support/basic.py +++ b/devito/ir/support/basic.py @@ -15,7 +15,7 @@ ) from devito.tools import ( CacheInstances, Tag, as_mapper, as_tuple, filter_sorted, flatten, is_integer, - memoized_generator, memoized_meth, smart_gt, smart_lt + memoized_func, memoized_generator, memoized_meth, smart_gt, smart_lt ) from devito.types import ( ComponentAccess, CriticalRegion, Dimension, DimensionTuple, Fence, Function, Symbol, @@ -326,12 +326,19 @@ def lex_le(self, other): def lex_lt(self, other): return self.timestamp < other.timestamp - @memoized_meth - def shifted(self, offset): - if offset == 0: + def rebuild(self, **kwargs): + access = kwargs.get('access', self.access) + mode = kwargs.get('mode', self.mode) + timestamp = kwargs.get('timestamp', self.timestamp) + ispace = kwargs.get('ispace', self.ispace) + + if access is self.access and \ + mode is self.mode and \ + timestamp is self.timestamp and \ + ispace is self.ispace: return self - return TimedAccess(self.access, self.mode, self.timestamp + offset, self.ispace) + return TimedAccess(access, mode, timestamp, ispace) 
@memoized_meth def distance(self, other, logical=False): @@ -868,37 +875,57 @@ class Scope(CacheInstances): Rule = Callable[[TimedAccess, TimedAccess], bool] @classmethod - def from_scopes(cls, scope0, scope1, rules=None): + @memoized_func(scope='build') + def from_scopes(cls, scope0, scope1): """ - Build a Scope out of two existing Scopes by reusing their cached accesses. + Build a synthetic Scope out of two existing Scopes by reusing their + cached reads and writes rather than rediscovering accesses from the + underlying expressions. + + This is used to analyze cross-scope dependences cheaply, for example in + loop-fusion hazard checks. Return None if the two Scopes cannot induce + any cross-scope dependences. """ offset = len(scope0.exprs) targets = ( - scope0.write_targets & scope1.functions + set(scope0.writes) & scope1.functions ) | ( - scope1.write_targets & scope0.functions + set(scope1.writes) & scope0.functions ) + if not targets: + return None + + def is_cross(source, sink): + t0 = source.timestamp + t1 = sink.timestamp + return t0 < offset <= t1 or t1 < offset <= t0 reads = {} writes = {} for f in targets: - shifted = tuple(i.shifted(offset) for i in scope1.getreads(f)) + shifted = tuple( + i.rebuild(timestamp=i.timestamp + offset) + for i in scope1.getreads(f) + ) accesses = scope0.getreads(f) if shifted: accesses = accesses + shifted if accesses else shifted if accesses: reads[f] = accesses - shifted = tuple(i.shifted(offset) for i in scope1.getwrites(f)) + shifted = tuple( + i.rebuild(timestamp=i.timestamp + offset) + for i in scope1.getwrites(f) + ) accesses = scope0.getwrites(f) if shifted: accesses = accesses + shifted if accesses else shifted if accesses: writes[f] = accesses - return cls((), rules=rules, reads=reads.items(), writes=writes.items()) + return cls((), rules=is_cross, reads=reads.items(), writes=writes.items()) @classmethod def _preprocess_args(cls, exprs: Expr | Iterable[Expr], @@ -1072,11 +1099,6 @@ def reads(self): return 
as_mapper(self.reads_gen(), key=lambda i: i.function) - @cached_property - def read_targets(self): - """The Functions read within the Scope.""" - return frozenset(self.reads) - @cached_property def read_only(self): """ @@ -1084,11 +1106,6 @@ def read_only(self): """ return set(self.reads) - set(self.writes) - @cached_property - def write_targets(self): - """The Functions written within the Scope.""" - return frozenset(self.writes) - @cached_property def has_barrier(self): """True if the Scope contains a fence-like control-flow object.""" @@ -1149,23 +1166,6 @@ def indexeds(self): def functions(self): return set(self.reads) | set(self.writes) - @memoized_meth - def may_interact(self, other, has_barrier=False): - """ - True if the Scope may induce cross-scope ordering constraints. - - This is a cheap pre-check used to avoid full dependence analysis when - two scopes do not touch any common Function through a write and no - fence-like object lies between them. - """ - if has_barrier or self.has_barrier or other.has_barrier: - return True - - if self.write_targets & other.functions: - return True - - return bool(other.write_targets & self.functions) - @memoized_meth def a_query(self, timestamps=None, modes=None): timestamps = as_tuple(timestamps) @@ -1461,18 +1461,6 @@ def is_regular(self): def vinf(entries): return Vector(*(entries + [S.Infinity])) - -def _cause_from_distance(findices, distance): - for i, j in zip(findices, distance, strict=False): - try: - if j > 0: - return i._defines - except TypeError: - return i._defines - - return frozenset() - - def disjoint_test(e0, e1, d, it): """ A rudimentary test to check if two accesses `e0` and `e1` along `d` within diff --git a/devito/passes/clusters/fusion.py b/devito/passes/clusters/fusion.py index 60b5977d6d..793cddedd8 100644 --- a/devito/passes/clusters/fusion.py +++ b/devito/passes/clusters/fusion.py @@ -14,29 +14,37 @@ __all__ = ['fuse'] -def _is_prefix_carried(dependence, prefix): - return bool(dependence.cause 
& prefix) +# No hazard: fusion may proceed. +NO_HAZARD = None +# Ordering hazard: preserve program order and forbid fusion. +EDGE = 'edge' +# Prefix anti-dependence: break the execution flow across the pair. +BREAK = 'break' @memoized_func(scope='build') -def _fusion_hazards(scope, prefix): +def _fusion_hazards(scope0, scope1, prefix): + scope = Scope.from_scopes(scope0, scope1) + if scope is None: + return NO_HAZARD + anti = False for i in scope.d_anti_gen(): - if _is_prefix_carried(i, prefix): - return True, True + if i.cause & prefix: + return BREAK anti = True if anti: - return False, True + return EDGE for i in scope.d_flow_gen(): - if not _is_prefix_carried(i, prefix): - return False, True + if not (i.cause & prefix): + return EDGE for _ in scope.d_output_gen(): - return False, True + return EDGE - return False, False + return NO_HAZARD class Fusion(Queue): @@ -270,42 +278,28 @@ def _build_dag(self, cgroups, prefix): prefix = frozenset(i.dim for i in as_tuple(prefix)) dag = DAG(nodes=cgroups) - scopes = [cg.scope for cg in cgroups] - - barrier_count = [0] - for scope in scopes: - barrier_count.append(barrier_count[-1] + int(scope.has_barrier)) - - for n, (cg0, scope0) in enumerate(zip(cgroups, scopes, strict=True)): - def is_cross(source, sink): - # True if a cross-ClusterGroup dependence, False otherwise - t0 = source.timestamp - t1 = sink.timestamp - v = len(cg0.exprs) - return t0 < v <= t1 or t1 < v <= t0 - - for n1, (cg1, scope1) in enumerate(zip(cgroups[n+1:], scopes[n+1:], - strict=True), start=n+1): - has_barrier = barrier_count[n1 + 1] > barrier_count[n] - if not scope0.may_interact(scope1, has_barrier): - continue + for n, cg0 in enumerate(cgroups): + # Track whether there is any fence between `cg0` and the current `cg1`. + fenced = cg0.scope.has_barrier - # Reuse the cached per-ClusterGroup accesses instead of - # rescanning the symbolic expressions for each candidate pair. 
- scope = Scope.from_scopes(scope0, scope1, rules=is_cross) - anti_prefix, forbids_fusion = _fusion_hazards(scope, prefix) + for n1, cg1 in enumerate(cgroups[n+1:], start=n+1): + fenced = fenced or cg1.scope.has_barrier + + hazard = _fusion_hazards(cg0.scope, cg1.scope, prefix) + if not (hazard or fenced): + continue # Anti-dependences along `prefix` break the execution flow # (intuitively, "the loop nests are to be kept separated") # * All ClusterGroups between `cg0` and `cg1` must precede `cg1` # * All ClusterGroups after `cg1` cannot precede `cg1` - if anti_prefix: + if hazard == BREAK: for cg2 in cgroups[n:n1]: dag.add_edge(cg2, cg1) for cg2 in cgroups[n1+1:]: dag.add_edge(cg1, cg2) break - elif has_barrier or forbids_fusion: + elif fenced or hazard == EDGE: # Any anti- and iaw-dependences impose that `cg1` follows `cg0` # and forbid any sort of fusion. Fences have the same effect dag.add_edge(cg0, cg1) From e123742a784267b7badda70b6df76274c9bd1358 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Wed, 8 Apr 2026 15:13:13 +0100 Subject: [PATCH 06/45] compiler: Avoid rebuilding IET if unnecessary --- devito/ir/iet/visitors.py | 100 +++++++++++++++++++++++++++++--------- tests/test_visitors.py | 10 +++- 2 files changed, 85 insertions(+), 25 deletions(-) diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 1eae6433e3..872a9e2602 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -7,6 +7,7 @@ import ctypes from collections import OrderedDict from collections.abc import Callable, Generator, Iterable, Iterator, Sequence +from contextlib import suppress from itertools import chain, groupby from typing import Any, Generic, TypeVar @@ -1332,12 +1333,12 @@ def transform(self, o, handle, **kwargs): else: children = o.children children = (tuple(handle) + children[0],) + tuple(children[1:]) - return o._rebuild(*children, **o.args_frozen) + return reuse_if_unchanged(o, *children, **o.args_frozen) else: # Replace `o` with `handle` 
if self.nested: children = [self._visit(i, **kwargs) for i in handle.children] - return handle._rebuild(*children, **handle.args_frozen) + return reuse_if_unchanged(handle, *children, **handle.args_frozen) else: return handle @@ -1346,7 +1347,12 @@ def visit_object(self, o, **kwargs): def visit_tuple(self, o, **kwargs): visited = tuple(self._visit(i, **kwargs) for i in o) - return tuple(i for i in visited if i is not None) + processed = tuple(i for i in visited if i is not None) + + if same_as_before(o, processed): + return o + + return processed visit_list = visit_tuple @@ -1357,7 +1363,7 @@ def visit_Node(self, o, **kwargs): children = [self._visit(i, **kwargs) for i in o.children] if o._traversable and not any(children) and any(o.children): return None - return o._rebuild(*children, **o.args_frozen) + return reuse_if_unchanged(o, *children, **o.args_frozen) def visit_Operator(self, o, **kwargs): raise ValueError("Cannot apply a Transformer visitor to an Operator directly") @@ -1375,7 +1381,7 @@ class Uxreplace(Transformer): """ def visit_Expression(self, o): - return o._rebuild(expr=uxreplace(o.expr, self.mapper)) + return reuse_if_unchanged(o, expr=uxreplace(o.expr, self.mapper)) def _visit_Iteration_common(self, o): nodes = self._visit(o.nodes) @@ -1392,8 +1398,8 @@ def visit_Iteration(self, o): nodes, dimension, limits, pragmas, uindices = \ self._visit_Iteration_common(o) - return o._rebuild(nodes=nodes, dimension=dimension, limits=limits, - pragmas=pragmas, uindices=uindices) + return reuse_if_unchanged(o, nodes=nodes, dimension=dimension, limits=limits, + pragmas=pragmas, uindices=uindices) def visit_PragmaIteration(self, o): nodes, dimension, limits, pragmas, uindices = \ @@ -1420,7 +1426,7 @@ def visit_Return(self, o): def visit_Callable(self, o): body = self._visit(o.body) parameters = [self.mapper.get(i, i) for i in o.parameters] - return o._rebuild(body=body, parameters=parameters) + return reuse_if_unchanged(o, body=body, parameters=parameters) def 
visit_Call(self, o): arguments = [] @@ -1431,47 +1437,47 @@ def visit_Call(self, o): arguments.append(uxreplace(i, self.mapper)) if o.retobj is not None: retobj = uxreplace(o.retobj, self.mapper) - return o._rebuild(arguments=arguments, retobj=retobj) + return reuse_if_unchanged(o, arguments=arguments, retobj=retobj) else: - return o._rebuild(arguments=arguments) + return reuse_if_unchanged(o, arguments=arguments) def visit_Lambda(self, o): body = self._visit(o.body) parameters = [self.mapper.get(i, i) for i in o.parameters] - return o._rebuild(body=body, parameters=parameters) + return reuse_if_unchanged(o, body=body, parameters=parameters) def visit_Conditional(self, o): condition = uxreplace(o.condition, self.mapper) then_body = self._visit(o.then_body) else_body = self._visit(o.else_body) - return o._rebuild(condition=condition, then_body=then_body, - else_body=else_body) + return reuse_if_unchanged(o, condition=condition, then_body=then_body, + else_body=else_body) def visit_Switch(self, o): condition = uxreplace(o.condition, self.mapper) nodes = self._visit(o.nodes) default = self._visit(o.default) - return o._rebuild(condition=condition, nodes=nodes, default=default) + return reuse_if_unchanged(o, condition=condition, nodes=nodes, default=default) def visit_PointerCast(self, o): function = self.mapper.get(o.function, o.function) obj = self.mapper.get(o.obj, o.obj) - return o._rebuild(function=function, obj=obj) + return reuse_if_unchanged(o, function=function, obj=obj) def visit_Dereference(self, o): pointee = self.mapper.get(o.pointee, o.pointee) pointer = self.mapper.get(o.pointer, o.pointer) - return o._rebuild(pointee=pointee, pointer=pointer) + return reuse_if_unchanged(o, pointee=pointee, pointer=pointer) def visit_Pragma(self, o): arguments = [uxreplace(i, self.mapper) for i in o.arguments] - return o._rebuild(arguments=arguments) + return reuse_if_unchanged(o, arguments=arguments) def visit_PragmaTransfer(self, o): function = uxreplace(o.function, 
self.mapper) arguments = [uxreplace(i, self.mapper) for i in o.arguments] if o.imask is None: - return o._rebuild(function=function, arguments=arguments) + return reuse_if_unchanged(o, function=function, arguments=arguments) # An `imask` may be None, a list of symbols/numbers, or a list of # 2-tuples representing ranges @@ -1483,25 +1489,26 @@ def visit_PragmaTransfer(self, o): uxreplace(j, self.mapper))) except TypeError: imask.append(uxreplace(v, self.mapper)) - return o._rebuild(function=function, imask=imask, arguments=arguments) + return reuse_if_unchanged(o, function=function, imask=imask, + arguments=arguments) def visit_ParallelTree(self, o): prefix = self._visit(o.prefix) body = self._visit(o.body) nthreads = self.mapper.get(o.nthreads, o.nthreads) - return o._rebuild(prefix=prefix, body=body, nthreads=nthreads) + return reuse_if_unchanged(o, prefix=prefix, body=body, nthreads=nthreads) def visit_HaloSpot(self, o): hs = o.halo_scheme fmapper = {self.mapper.get(k, k): v for k, v in hs.fmapper.items()} halo_scheme = hs._rebuild(fmapper=fmapper) body = self._visit(o.body) - return o._rebuild(halo_scheme=halo_scheme, body=body) + return reuse_if_unchanged(o, halo_scheme=halo_scheme, body=body) def visit_While(self, o, **kwargs): condition = uxreplace(o.condition, self.mapper) body = self._visit(o.body) - return o._rebuild(condition=condition, body=body) + return reuse_if_unchanged(o, condition=condition, body=body) visit_ThreadedProdder = visit_Call @@ -1510,8 +1517,8 @@ def visit_KernelLaunch(self, o): grid = self.mapper.get(o.grid, o.grid) block = self.mapper.get(o.block, o.block) stream = self.mapper.get(o.stream, o.stream) - return o._rebuild(grid=grid, block=block, stream=stream, - arguments=arguments) + return reuse_if_unchanged(o, grid=grid, block=block, stream=stream, + arguments=arguments) # Utils @@ -1519,6 +1526,51 @@ def visit_KernelLaunch(self, o): blankline = c.Line("") +def same_as_before(old, new): + if old is new: + return True + + if 
isinstance(old, (tuple, list)) and isinstance(new, (tuple, list)): + return len(old) == len(new) and all( + same_as_before(i, j) for i, j in zip(old, new, strict=True) + ) + + if type(old) is not type(new): + return False + + if isinstance(old, dict): + return old.keys() == new.keys() and all( + same_as_before(v, new[k]) for k, v in old.items() + ) + + return False + + +def reuse_if_unchanged(o, *children, **kwargs): + def same_kwarg(k, v): + with suppress(AttributeError): + if same_as_before(getattr(o, k), v): + return True + + with suppress(KeyError): + if same_as_before(o.args[k], v): + return True + + return False + + if children: + same_children = all( + same_as_before(i, j) for i, j in zip(o.children, children, strict=True) + ) + if not same_children: + return o._rebuild(*children, **kwargs) + + if kwargs and not all(same_kwarg(k, v) for k, v in kwargs.items()): + return o._rebuild(*children, **kwargs) + + return o + + def printAST(node, verbose=True): return PrintAST(verbose=verbose)._visit(node) diff --git a/tests/test_visitors.py b/tests/test_visitors.py index 06eb933351..7acd7bbf86 100644 --- a/tests/test_visitors.py +++ b/tests/test_visitors.py @@ -7,7 +7,7 @@ from devito.ir.iet import ( Block, Call, Callable, Conditional, Expression, FindApplications, FindNodes, FindSections, FindSymbols, IsPerfectIteration, Iteration, MapNodes, Transformer, - printAST + Uxreplace, printAST ) from devito.types import Array, SpaceDimension, Symbol @@ -249,6 +249,14 @@ def test_transformer_wrap(exprs, block1, block2, block3): assert "a[i] = a[i] + b[i] + 5.0F;" in newcode +def test_transformer_reuses_untouched_node(block1): + assert Transformer({}).visit(block1) is block1 + + +def test_uxreplace_reuses_untouched_node(block1): + assert Uxreplace({}).visit(block1) is block1 + + def test_transformer_replace(exprs, block1, block2, block3): """Basic transformer test that replaces an expression""" line1 = '// Replaced expression' From d72b448c3d1b9078a61902e796f1539f27d96262 
Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Wed, 8 Apr 2026 15:44:01 +0100 Subject: [PATCH 07/45] compiler: Fix propagation of transitive IET arg updates --- devito/passes/iet/engine.py | 15 +++++++++++---- 1 file changed, 11 insertions(+), 4 deletions(-) diff --git a/devito/passes/iet/engine.py b/devito/passes/iet/engine.py index 9b936fba76..ad744d19c6 100644 --- a/devito/passes/iet/engine.py +++ b/devito/passes/iet/engine.py @@ -135,6 +135,7 @@ def apply(self, func, **kwargs): efuncs = dict(self.efuncs) for i in dag.topological_sort(): efunc, metadata = func(efuncs[i], **kwargs) + new_efuncs = metadata.get('efuncs', []) self.includes.extend(as_tuple(metadata.get('includes'))) self.headers.extend(as_tuple(metadata.get('headers'))) @@ -151,11 +152,9 @@ def apply(self, func, **kwargs): except KeyError: pass - if efunc is efuncs[i]: + if efunc is efuncs[i] and not new_efuncs: continue - new_efuncs = metadata.get('efuncs', []) - efuncs[i] = efunc efuncs.update(dict([(i.name, i) for i in new_efuncs])) @@ -780,6 +779,14 @@ def _filter(v, efunc=None): mapper = {c: c._rebuild(arguments=_filter(c.arguments)) for c in FindNodes(Call).visit(efuncs[n]) if c.name == root.name} - efuncs[n] = Transformer(mapper).visit(efuncs[n]) + if not mapper: + continue + + efunc = Transformer(mapper).visit(efuncs[n]) + if efunc is efuncs[n]: + continue + + efuncs[n] = efunc + efuncs = update_args(efunc, efuncs, dag) return efuncs From a6232b0b52f40d3f6b9369f4aeb252d614477c8a Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Wed, 8 Apr 2026 16:39:39 +0100 Subject: [PATCH 08/45] compiler: cache CGen printers by settings --- devito/ir/cgen/printer.py | 31 ++++++++++++++++++++++++++----- devito/ir/iet/visitors.py | 5 +++-- tests/test_dtypes.py | 15 ++++++++++++++- 3 files changed, 43 insertions(+), 8 deletions(-) diff --git a/devito/ir/cgen/printer.py b/devito/ir/cgen/printer.py index 8ef479955e..38acc7b7a5 100644 --- a/devito/ir/cgen/printer.py +++ b/devito/ir/cgen/printer.py @@ -21,7 
+21,10 @@ from devito.tools import ctypes_to_cstr, ctypes_vector_mapper, dtype_to_ctype from devito.types.basic import AbstractFunction -__all__ = ['BasePrinter', 'ccode'] +__all__ = ['BasePrinter', 'ccode', 'get_printer'] + +_preset_dtypes = (np.float32, np.float64, np.complex64, np.complex128) +_printer_registry = {} class BasePrinter(CodePrinter): @@ -449,15 +452,33 @@ def _print_Fallback(self, expr): sympy.printing.str.StrPrinter._print_Add = BasePrinter._print_Add -def ccode(expr, printer=None, **settings): +def get_printer(printer, dtype=None): + try: + registry = _printer_registry[printer] + except KeyError: + default = printer() + registry = {None: default, default.dtype: default} + for i in _preset_dtypes: + registry.setdefault(i, printer(settings={'dtype': i})) + _printer_registry[printer] = registry + + try: + return registry[dtype] + except KeyError: + handle = printer(settings={'dtype': dtype}) + registry[dtype] = handle + return handle + + +def ccode(expr, printer=None, dtype=None): """Generate C++ code from an expression. Parameters ---------- expr : expr-like The expression to be printed. - settings : dict - Options for code printing. + dtype : data-type, optional + Data type used by the printer. 
Returns ------- @@ -468,4 +489,4 @@ def ccode(expr, printer=None, **settings): if printer is None: from devito.passes.iet.languages.C import CPrinter printer = CPrinter - return printer(settings=settings).doprint(expr, None) + return get_printer(printer, dtype).doprint(expr, None) diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 872a9e2602..7959ca2186 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -16,6 +16,7 @@ from sympy.core.function import Application from devito.exceptions import CompilationError +from devito.ir.cgen.printer import get_printer from devito.ir.iet.nodes import ( BlankLine, Call, Expression, ExpressionBundle, Iteration, Lambda, ListMajor, Node, Section @@ -256,8 +257,8 @@ def __init__(self, *args, printer=None, **kwargs): printer = CPrinter self.printer = printer - def ccode(self, expr, **kwargs): - return self.printer(settings=kwargs).doprint(expr, None) + def ccode(self, expr, dtype=None): + return get_printer(self.printer, dtype).doprint(expr, None) @property def _qualifiers_mapper(self): diff --git a/tests/test_dtypes.py b/tests/test_dtypes.py index 18d0d3609c..4206f70057 100644 --- a/tests/test_dtypes.py +++ b/tests/test_dtypes.py @@ -10,7 +10,7 @@ from devito import Constant, Eq, Function, Grid, Operator, configuration, exp, log, sin from devito.arch.compiler import CustomCompiler, GNUCompiler from devito.exceptions import InvalidOperator -from devito.ir.cgen.printer import BasePrinter +from devito.ir.cgen.printer import BasePrinter, get_printer from devito.passes.iet.langbase import LangBB from devito.passes.iet.languages.C import CBB, CPrinter from devito.passes.iet.languages.openacc import AccBB, AccPrinter @@ -204,6 +204,19 @@ def test_math_functions(dtype: np.dtype[np.inexact], assert call_str in str(op) +def test_printer_registry() -> None: + default = get_printer(CPrinter) + + assert get_printer(CPrinter) is default + assert get_printer(CPrinter, np.float32) is default + + float64 = 
get_printer(CPrinter, np.float64) + assert get_printer(CPrinter, np.float64) is float64 + + float16 = get_printer(CPrinter, np.float16) + assert get_printer(CPrinter, np.float16) is float16 + + @pytest.mark.parametrize('dtype', [np.complex64, np.complex128]) def test_complex_override(dtype: np.dtype[np.complexfloating]) -> None: """ From 3466e0dab1d03ca382f53498e3e7cdc0832bc987 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Thu, 9 Apr 2026 11:06:38 +0100 Subject: [PATCH 09/45] compiler: Enhance Scope to improve DDA turnaround --- devito/ir/support/basic.py | 62 +++++++++++++++++++++++++++++--------- 1 file changed, 48 insertions(+), 14 deletions(-) diff --git a/devito/ir/support/basic.py b/devito/ir/support/basic.py index 86c6c4e688..4b188a3130 100644 --- a/devito/ir/support/basic.py +++ b/devito/ir/support/basic.py @@ -1,6 +1,6 @@ from collections.abc import Callable, Iterable from contextlib import suppress -from functools import cached_property +from functools import cached_property, wraps from itertools import chain, product import sympy @@ -874,6 +874,19 @@ class Scope(CacheInstances): # Describes a rule for dependencies Rule = Callable[[TimedAccess, TimedAccess], bool] + def normalize_input(func): + + @wraps(func) + def wrapper(self, *args, writes=None, **kwargs): + mapped = {} + for k in as_tuple(writes or self.writes): + v = self.getwrites(k) + if v: + mapped[k] = v + return func(self, *args, writes=mapped, **kwargs) + + return wrapper + @classmethod @memoized_func(scope='build') def from_scopes(cls, scope0, scope1): @@ -1174,9 +1187,14 @@ def a_query(self, timestamps=None, modes=None): if a.timestamp in timestamps and a.mode in modes) @memoized_generator - def d_flow_gen(self): - """Generate the flow (or "read-after-write") dependences.""" - for k, v in self.writes.items(): + @normalize_input + def d_flow_gen(self, writes=None): + """ + Generate the flow (or "read-after-write") dependences. 
+ + If ``writes`` is provided, restrict the analysis to those Functions. + """ + for k, v in writes.items(): reads = tuple(self.reads_smart_gen(k)) for w in v: for r in reads: @@ -1206,9 +1224,14 @@ def d_flow(self): return DependenceGroup(self.d_flow_gen()) @memoized_generator - def d_anti_gen(self, depcls=Dependence): - """Generate the anti (or "write-after-read") dependences.""" - for k, v in self.writes.items(): + @normalize_input + def d_anti_gen(self, depcls=Dependence, writes=None): + """ + Generate the anti (or "write-after-read") dependences. + + If ``writes`` is provided, restrict the analysis to those Functions. + """ + for k, v in writes.items(): reads = tuple(self.reads_smart_gen(k)) for w in v: for r in reads: @@ -1246,11 +1269,16 @@ def d_anti_logical(self): return DependenceGroup(self.d_anti_gen(depcls=LogicalDependence)) @memoized_generator - def d_output_gen(self): - """Generate the output (or "write-after-write") dependences.""" - for k, v in self.writes.items(): + @normalize_input + def d_output_gen(self, writes=None): + """ + Generate the output (or "write-after-write") dependences. + + If ``writes`` is provided, restrict the analysis to those Functions. + """ + for k, v in writes.items(): for w1 in v: - for w2 in self.writes.get(k, []): + for w2 in v: if any(not rule(w2, w1) for rule in self.rules): continue @@ -1274,9 +1302,15 @@ def d_output(self): """Output (or "write-after-write") dependences.""" return DependenceGroup(self.d_output_gen()) - def d_all_gen(self): - """Generate all flow, anti and output dependences.""" - return chain(self.d_flow_gen(), self.d_anti_gen(), self.d_output_gen()) + def d_all_gen(self, writes=None): + """ + Generate all flow, anti and output dependences. + + If ``writes`` is provided, restrict the analysis to those Functions. 
+ """ + return chain(self.d_flow_gen(writes=writes), + self.d_anti_gen(writes=writes), + self.d_output_gen(writes=writes)) @cached_property def d_all(self): From b929454458872cd5c737bf5b0d2c1913e7b136b8 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Thu, 9 Apr 2026 11:06:57 +0100 Subject: [PATCH 10/45] compiler: Exploit the new Scope API --- devito/ir/clusters/analysis.py | 6 +----- devito/ir/support/basic.py | 11 ++++++----- devito/passes/iet/mpi.py | 23 +++++++++++------------ 3 files changed, 18 insertions(+), 22 deletions(-) diff --git a/devito/ir/clusters/analysis.py b/devito/ir/clusters/analysis.py index 5ebae71b0f..f78f1ee456 100644 --- a/devito/ir/clusters/analysis.py +++ b/devito/ir/clusters/analysis.py @@ -101,7 +101,7 @@ def _callback(self, clusters, dim, prefix): is_parallel_atomic = False scope = Scope(flatten(c.exprs for c in clusters)) - for dep in scope.d_all_gen(): + for dep in scope.d_all_gen(writes=scope.writes_tensor): test00 = dep.is_indep(dim) and not dep.is_storage_related(dim) test01 = all(dep.is_reduce_atmost(i) for i in prev) if test00 and test01: @@ -112,10 +112,6 @@ def _callback(self, clusters, dim, prefix): is_parallel_indep &= (dep.distance_mapper.get(dim.root) == 0) continue - if dep.function in scope.initialized: - # False alarm, the dependence is over a locally-defined symbol - continue - if dep.is_reduction: is_parallel_atomic = True continue diff --git a/devito/ir/support/basic.py b/devito/ir/support/basic.py index 4b188a3130..05286810fc 100644 --- a/devito/ir/support/basic.py +++ b/devito/ir/support/basic.py @@ -1002,6 +1002,12 @@ def writes(self): return as_mapper(self.writes_gen(), key=lambda i: i.function) + @cached_property + def writes_tensor(self): + initialized = frozenset(e.lhs.function for e in self.exprs + if not e.is_Reduction and e.is_scalar) + return frozenset(self.writes) - initialized + @memoized_generator def reads_explicit_gen(self): """ @@ -1124,11 +1130,6 @@ def has_barrier(self): """True if the Scope 
contains a fence-like control-flow object.""" return any(isinstance(e.rhs, (Fence, CriticalRegion)) for e in self.exprs) - @cached_property - def initialized(self): - return frozenset(e.lhs.function for e in self.exprs - if not e.is_Reduction and e.is_scalar) - def getreads(self, function): return as_tuple(self.reads.get(function)) diff --git a/devito/passes/iet/mpi.py b/devito/passes/iet/mpi.py index 2de99ee002..db835400ae 100644 --- a/devito/passes/iet/mpi.py +++ b/devito/passes/iet/mpi.py @@ -293,17 +293,16 @@ def _mark_overlappable(iet): scope = Scope([n.expr for n in exprs]) - for dep in scope.d_all_gen(): - if dep.function in hs.functions: - cause = dep.cause & hs.dimensions - if any(dep.distance_mapper[d] is S.Infinity for d in cause): - # E.g., dependencies across PARALLEL iterations - # for x - # for y - # ... = ... f[x, y-1] ... - # for y - # f[x, y] = ... - break + for dep in scope.d_all_gen(writes=hs.functions): + cause = dep.cause & hs.dimensions + if any(dep.distance_mapper[d] is S.Infinity for d in cause): + # E.g., dependencies across PARALLEL iterations + # for x + # for y + # ... = ... f[x, y-1] ... + # for y + # f[x, y] = ... + break else: # All good -- we can perform comp/comm overlap! 
found.append(hs) @@ -507,7 +506,7 @@ def rule1(dep, loc_indices): for d, v in loc_indices.items()) for f, v in hsf.fmapper.items(): - for dep in scope.d_flow.project(f): + for dep in scope.d_flow_gen(writes=f): if not rule0(dep) and not rule1(dep, v.loc_indices): return False From fe171f45c2882c4fe13dd1d13f23f0ec66730b17 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Thu, 9 Apr 2026 16:47:56 +0100 Subject: [PATCH 11/45] compiler: include ClusterGroup ispace in equality semantics --- devito/ir/clusters/cluster.py | 11 +++++++++++ tests/test_ir.py | 20 ++++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/devito/ir/clusters/cluster.py b/devito/ir/clusters/cluster.py index 546a8901f5..041af4e1b1 100644 --- a/devito/ir/clusters/cluster.py +++ b/devito/ir/clusters/cluster.py @@ -568,6 +568,17 @@ def __new__(cls, clusters, ispace=None): return obj + def __eq__(self, other): + return (isinstance(other, ClusterGroup) and + super().__eq__(other) and + self._ispace == other._ispace) + + def __ne__(self, other): + return not self == other + + def __hash__(self): + return hash((tuple(self), self._ispace)) + @classmethod def concatenate(cls, *cgroups): return list(chain(*cgroups)) diff --git a/tests/test_ir.py b/tests/test_ir.py index 5ee01273ff..7611d39dfc 100644 --- a/tests/test_ir.py +++ b/tests/test_ir.py @@ -7,6 +7,7 @@ Constant, Dimension, Eq, Function, Grid, Inc, Operator, SubDimension, TimeFunction, switchconfig ) +from devito.ir.clusters import Cluster, ClusterGroup from devito.ir.cgen import ccode from devito.ir.equations import LoweredEq from devito.ir.equations.algorithms import dimension_sort @@ -1161,6 +1162,25 @@ def test_dimension_sort(self, expr, expected): assert list(dimension_sort(expr)) == eval(expected) +class TestClusterGroup: + + def test_eq_hash_include_ispace(self): + grid = Grid(shape=(4,)) + x, = grid.dimensions + + f = Function(name='f', grid=grid) + cluster = Cluster(Eq(f[x], 1)) + + ispace0 = IterationSpace([Interval(x, 0, 0)], 
directions={x: Forward}) + ispace1 = IterationSpace([Interval(x, 0, 0)], directions={x: Backward}) + + cgroup0 = ClusterGroup((cluster,), ispace0) + cgroup1 = ClusterGroup((cluster,), ispace1) + + assert cgroup0 != cgroup1 + assert len({cgroup0, cgroup1}) == 2 + + class TestGuards: def test_guard_overflow(self): From 16aa3d9f475d8692ef89c7dd72b8cb496bc1c783 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Thu, 9 Apr 2026 17:07:42 +0100 Subject: [PATCH 12/45] compiler: Add update_args= contract to spare compilation time --- devito/passes/iet/definitions.py | 4 ++-- devito/passes/iet/engine.py | 17 +++++++++++------ devito/passes/iet/instrument.py | 4 ++-- devito/passes/iet/misc.py | 6 +++--- tests/test_iet.py | 23 ++++++++++++++++++++++- 5 files changed, 40 insertions(+), 14 deletions(-) diff --git a/devito/passes/iet/definitions.py b/devito/passes/iet/definitions.py index 2f1cce8f10..bef8f0f1f4 100644 --- a/devito/passes/iet/definitions.py +++ b/devito/passes/iet/definitions.py @@ -518,7 +518,7 @@ def place_definitions(self, iet, globs=None, **kwargs): 'globals': as_tuple(globs), 'includes': as_tuple(sorted(storage.includes))} - @iet_pass + @iet_pass(updates_args=False) def place_casts(self, iet, **kwargs): """ Create a new IET with the necessary type casts. @@ -669,7 +669,7 @@ def place_transfers(self, iet, data_movs=None, ctx=None, **kwargs): return iet, {'efuncs': efuncs} - @iet_pass + @iet_pass(updates_args=False) def place_devptr(self, iet, **kwargs): """ Transform `iet` such that device pointers are used in DeviceCalls. diff --git a/devito/passes/iet/engine.py b/devito/passes/iet/engine.py index ad744d19c6..0e92de1ec2 100644 --- a/devito/passes/iet/engine.py +++ b/devito/passes/iet/engine.py @@ -125,7 +125,7 @@ def sync_mapper(self): return found - def apply(self, func, **kwargs): + def apply(self, func, *, updates_args=True, **kwargs): """ Apply `func` to all nodes in the Graph. This changes the state of the Graph. 
""" @@ -158,9 +158,10 @@ def apply(self, func, **kwargs): efuncs[i] = efunc efuncs.update(dict([(i.name, i) for i in new_efuncs])) - # Update the parameters / arguments lists since `func` may have - # introduced or removed objects - efuncs = update_args(efunc, efuncs, dag) + # Update the parameters / arguments lists if the pass may have + # introduced or removed objects. + if updates_args: + efuncs = update_args(efunc, efuncs, dag) # Minimize code size if len(efuncs) > len(self.efuncs): @@ -205,13 +206,16 @@ def filter(self, key): ) -def iet_pass(func): +def iet_pass(func=None, *, updates_args=True): + if func is None: + return partial(iet_pass, updates_args=updates_args) + if isinstance(func, tuple): assert len(func) == 2 and func[0] is iet_visit call = lambda graph: graph.visit func = func[1] else: - call = lambda graph: graph.apply + call = lambda graph: partial(graph.apply, updates_args=updates_args) @wraps(func) def wrapper(*args, **kwargs): @@ -230,6 +234,7 @@ def wrapper(*args, **kwargs): # Instance method case self, graph = args return maybe_timed(call(graph), func.__name__)(partial(func, self), **kwargs) + return wrapper diff --git a/devito/passes/iet/instrument.py b/devito/passes/iet/instrument.py index 9bd1ce2134..7251683f0f 100644 --- a/devito/passes/iet/instrument.py +++ b/devito/passes/iet/instrument.py @@ -27,7 +27,7 @@ def instrument(graph, **kwargs): sync_sections(graph, **kwargs) -@iet_pass +@iet_pass(updates_args=False) def track_subsections(iet, **kwargs): """ Add sub-Sections to the `profiler`. Sub-Sections include: @@ -122,7 +122,7 @@ def instrument_sections(iet, **kwargs): return piet, {'headers': headers} -@iet_pass +@iet_pass(updates_args=False) def sync_sections(iet, langbb=None, profiler=None, **kwargs): """ Wrap sections within global barriers if deemed necessary by the profiler. 
diff --git a/devito/passes/iet/misc.py b/devito/passes/iet/misc.py index 0a631bf3a2..1d2324e3d1 100644 --- a/devito/passes/iet/misc.py +++ b/devito/passes/iet/misc.py @@ -29,7 +29,7 @@ ] -@iet_pass +@iet_pass(updates_args=False) def avoid_denormals(iet, platform=None, **kwargs): """ Introduce nodes in the Iteration/Expression tree that will expand to C @@ -60,7 +60,7 @@ def avoid_denormals(iet, platform=None, **kwargs): return iet, {'includes': ('xmmintrin.h', 'pmmintrin.h')} -@iet_pass +@iet_pass(updates_args=False) def hoist_prodders(iet): """ Move Prodders within the outer levels of an Iteration tree. @@ -151,7 +151,7 @@ def generate_macros(graph, **kwargs): _generate_macros(graph, tracker={}, **kwargs) -@iet_pass +@iet_pass(updates_args=False) def _generate_macros(iet, tracker=None, langbb=None, printer=CPrinter, **kwargs): # Derive the Macros necessary for the FIndexeds iet = _generate_macros_findexeds(iet, tracker=tracker, **kwargs) diff --git a/tests/test_iet.py b/tests/test_iet.py index e8e8f8444f..129c895a3e 100644 --- a/tests/test_iet.py +++ b/tests/test_iet.py @@ -14,7 +14,8 @@ ElementalFunction, FindSymbols, Iteration, KernelLaunch, Lambda, List, Switch, Transformer, filter_iterations, make_efunc, retrieve_iteration_tree ) -from devito.passes.iet.engine import Graph +from devito.passes.iet import engine as iet_engine +from devito.passes.iet.engine import Graph, iet_pass from devito.passes.iet.languages.C import CDataManager from devito.symbolics import ( FLOAT, Byref, Class, FieldFromComposite, InlineIf, ListInitializer, Macro, SizeOf, @@ -539,6 +540,26 @@ def test_complex_array(): "float _Complex **restrict a_vec __attribute__ ((aligned (64)));" +def test_iet_pass_skip_update_args(monkeypatch): + x = Symbol(name='x') + y = Symbol(name='y') + + foo = Callable('foo', DummyExpr(x, y), 'void', parameters=(x, y)) + graph = Graph(foo) + + @iet_pass(updates_args=False) + def inject_expr(iet): + body = iet.body._rebuild(body=iet.body.body + (DummyExpr(x, x),)) 
+ return iet._rebuild(body=body), {} + + monkeypatch.setattr(iet_engine, 'update_args', + lambda *args, **kwargs: pytest.fail("update_args called")) + + inject_expr(graph) + + assert graph.root.parameters is foo.parameters + + def test_special_array_definition(): class MyArray(Array): From f29f238da4fd90e016f55a32699d78e50974c7b2 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Fri, 10 Apr 2026 13:08:18 +0100 Subject: [PATCH 13/45] compiler: Add heuristic for topofuse='maximal' --- devito/passes/clusters/derivatives.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/devito/passes/clusters/derivatives.py b/devito/passes/clusters/derivatives.py index fcbdbd5b8a..b8c4108b83 100644 --- a/devito/passes/clusters/derivatives.py +++ b/devito/passes/clusters/derivatives.py @@ -6,7 +6,7 @@ from devito.finite_differences import IndexDerivative, Weights from devito.ir import Backward, Forward, Interval, IterationSpace, Queue from devito.passes.clusters.fusion import fuse -from devito.symbolics import BasicWrapperMixin, reuse_if_untouched, uxreplace +from devito.symbolics import BasicWrapperMixin, reuse_if_untouched, search, uxreplace from devito.tools import infer_dtype, timed_pass from devito.types import Eq, Inc, Indexed, Symbol @@ -15,13 +15,16 @@ @timed_pass() def lower_index_derivatives(clusters, mode=None, **kwargs): + max_depth = _max_index_derivative_depth(clusters) clusters, weights, mapper = _lower_index_derivatives(clusters, **kwargs) if not weights: return clusters if mode != 'noop': - clusters = fuse(clusters, toposort='maximal') + for _ in range(max_depth): + clusters = fuse(clusters, toposort='nofuse') + clusters = fuse(clusters, toposort=False) # At this point we can detect redundancies induced by inner derivatives that # previously were just not detectable via e.g. plain CSE. 
For example, if @@ -258,3 +261,16 @@ def callback(self, clusters, prefix, subs0=None, seen=None): seen.update(processed) return processed + + +# *** Utils + + +def _max_index_derivative_depth(clusters): + max_depth = 0 + + for c in clusters: + for i in search(c.exprs, IndexDerivative): + max_depth = max(max_depth, i.depth) + + return max_depth From acd5956c03a36cceb79aafbda0cc09122d9a8ac7 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Fri, 10 Apr 2026 14:10:18 +0100 Subject: [PATCH 14/45] misc: Patch NVIDIA_VISIBLE_DEVICES and DeviceID --- devito/arch/archinfo.py | 1 + devito/operator/operator.py | 8 +++++--- tests/test_gpu_common.py | 2 ++ 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/devito/arch/archinfo.py b/devito/arch/archinfo.py index ea4025a3ab..0eb1bed0fa 100644 --- a/devito/arch/archinfo.py +++ b/devito/arch/archinfo.py @@ -497,6 +497,7 @@ def parse_product_arch(): def get_visible_devices(): device_vars = ( 'CUDA_VISIBLE_DEVICES', + 'NVIDIA_VISIBLE_DEVICES', 'ROCR_VISIBLE_DEVICES', 'HIP_VISIBLE_DEVICES' ) diff --git a/devito/operator/operator.py b/devito/operator/operator.py index 66e9857132..8a7ceaf259 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -1423,11 +1423,13 @@ def _physical_deviceid(self): if isinstance(self.platform, Device): # Get the physical device ID (as CUDA_VISIBLE_DEVICES may be set) logical_deviceid = self.get('deviceid', -1) + visible_device_var, visible_devices = get_visible_devices() if logical_deviceid < 0: rank = self.comm.Get_rank() if self.comm != MPI.COMM_NULL else 0 - logical_deviceid = rank - - visible_device_var, visible_devices = get_visible_devices() + if visible_devices is None: + logical_deviceid = rank + else: + logical_deviceid = rank % len(visible_devices) if visible_devices is None: return logical_deviceid elif len(visible_devices) == 1: diff --git a/tests/test_gpu_common.py b/tests/test_gpu_common.py index e84d0df5d8..4bde13a7fb 100644 --- a/tests/test_gpu_common.py 
+++ b/tests/test_gpu_common.py @@ -82,6 +82,8 @@ class TestDeviceID: @pytest.mark.parametrize('env_variables', [{"CUDA_VISIBLE_DEVICES": "1"}, {"CUDA_VISIBLE_DEVICES": "1,2"}, {"CUDA_VISIBLE_DEVICES": "1,0"}, + {"NVIDIA_VISIBLE_DEVICES": "1"}, + {"NVIDIA_VISIBLE_DEVICES": "1,2"}, {"ROCR_VISIBLE_DEVICES": "1"}, {"HIP_VISIBLE_DEVICES": " 1"}]) def test_visible_devices(self, env_variables): From 815bd7065dec846ee49e4adb9bfbeb0767c7492c Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Wed, 15 Apr 2026 09:26:56 +0100 Subject: [PATCH 15/45] tools: Add DefaultFrozenDict --- devito/tools/data_structures.py | 37 +++++++++++++++++++++++++++++++++ tests/test_tools.py | 25 ++++++++++++++++++++++ 2 files changed, 62 insertions(+) diff --git a/devito/tools/data_structures.py b/devito/tools/data_structures.py index d875878d02..2bf901c3f4 100644 --- a/devito/tools/data_structures.py +++ b/devito/tools/data_structures.py @@ -14,6 +14,7 @@ __all__ = [ 'DAG', 'Bunch', + 'DefaultFrozenDict', 'DefaultOrderedDict', 'EnrichedTuple', 'MemoryEstimate', @@ -672,6 +673,42 @@ def __hash__(self): return self._hash +class DefaultFrozenDict(frozendict): + """ + An immutable mapper that returns a configured default value for missing + keys when accessed via ``obj[key]``. + + Unlike :class:`collections.defaultdict`, the mapping remains immutable and + missing-key access does not mutate internal state. The ``get`` method + preserves the standard dictionary semantics, defaulting to ``None`` unless + the caller provides an explicit fallback. 
+ """ + + _sentinel = object() + + def __init__(self, *args, default=_sentinel, **kwargs): + self._default = default + super().__init__(*args, **kwargs) + + def __getitem__(self, key): + try: + return self._dict[key] + except KeyError: + if self._default is self._sentinel: + raise + + if callable(self._default): + return self._default() + else: + return self._default + + def get(self, key, default=None): + return self._dict.get(key, default) + + def copy(self, **add_or_replace): + return self.__class__(self, default=self._default, **add_or_replace) + + class MemoryEstimate(frozendict): """ An immutable mapper for a memory estimate, providing the estimated memory diff --git a/tests/test_tools.py b/tests/test_tools.py index bef55a82c6..14ccd990b9 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -7,6 +7,7 @@ from devito import Eq, Operator, switchenv from devito.tools import ( + DefaultFrozenDict, CacheInstances, UnboundedMultiTuple, UnboundTuple, ctypes_to_cstr, filter_ordered, toposort, transitive_closure ) @@ -61,6 +62,30 @@ def test_transitive_closure(): assert mapper == {a: d, b: d, c: d, f: e} +def test_default_frozen_dict(): + mapper = DefaultFrozenDict({'a': 'b'}, default='c') + + assert mapper['a'] == 'b' + assert mapper['d'] == 'c' + assert mapper.get('d') is None + assert mapper.get('d', 'e') == 'e' + + copied = mapper.copy(c='d') + assert copied['c'] == 'd' + assert copied['e'] == 'c' + + +def test_default_frozen_dict_factory(): + mapper = DefaultFrozenDict(default=lambda: []) + + v0 = mapper[a] + v1 = mapper[b] + + assert v0 == [] + assert v1 == [] + assert v0 is not v1 + + def test_loops_in_transitive_closure(): a = Symbol('a') b = Symbol('b') From 0c060e5f41f02d3a5b80edd887139e69d67ad21b Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Wed, 15 Apr 2026 11:47:26 +0100 Subject: [PATCH 16/45] compiler: Remove dead NodesExprs.dspace --- devito/ir/stree/algorithms.py | 2 +- devito/ir/stree/tree.py | 3 +-- 2 files changed, 2 insertions(+), 3 
deletions(-) diff --git a/devito/ir/stree/algorithms.py b/devito/ir/stree/algorithms.py index d4a761dfc8..68fc697d3e 100644 --- a/devito/ir/stree/algorithms.py +++ b/devito/ir/stree/algorithms.py @@ -111,7 +111,7 @@ def stree_build(clusters, profiler=None, **kwargs): else: parent = tip - NodeExprs(exprs, c.ispace, c.dspace, c.ops, c.traffic, parent) + NodeExprs(exprs, c.ispace, c.ops, c.traffic, parent) # Nest within a NodeSection if possible if profiler is None or \ diff --git a/devito/ir/stree/tree.py b/devito/ir/stree/tree.py index e033c9fd15..96e498396d 100644 --- a/devito/ir/stree/tree.py +++ b/devito/ir/stree/tree.py @@ -115,11 +115,10 @@ class NodeExprs(ScheduleTree): is_Exprs = True - def __init__(self, exprs, ispace, dspace, ops, traffic, parent=None): + def __init__(self, exprs, ispace, ops, traffic, parent=None): super().__init__(parent) self.exprs = exprs self.ispace = ispace - self.dspace = dspace self.ops = ops self.traffic = traffic From 5bb7e6dbb5cf7373b64f32322c5b25238d2e6519 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Fri, 24 Apr 2026 13:55:58 +0100 Subject: [PATCH 17/45] tools: Add reuse_if_unchanged and exploit it --- devito/ir/equations/equation.py | 3 ++- devito/tools/memoization.py | 42 +++++++++++++++++++++++++++++++-- 2 files changed, 42 insertions(+), 3 deletions(-) diff --git a/devito/ir/equations/equation.py b/devito/ir/equations/equation.py index e8221a2364..283b312daa 100644 --- a/devito/ir/equations/equation.py +++ b/devito/ir/equations/equation.py @@ -10,7 +10,7 @@ GuardFactor, Interval, IntervalGroup, IterationSpace, Stencil, detect_accesses ) from devito.symbolics import IntDiv, limits_mapper, retrieve_accesses, uxreplace -from devito.tools import Pickable, Tag, filter_sorted, frozendict +from devito.tools import Pickable, Tag, filter_sorted, frozendict, reuse_if_unchanged from devito.types import Eq, Inc, ReduceMax, ReduceMin, ReduceMinMax, relational_min __all__ = [ @@ -359,6 +359,7 @@ class ClusterizedEq(IREq): These two 
properties make a ClusterizedEq suitable for use in a Cluster. """ + @reuse_if_unchanged('__rkwargs__') def __new__(cls, *args, **kwargs): if len(args) == 1: # origin: ClusterizedEq(expr, **kwargs) diff --git a/devito/tools/memoization.py b/devito/tools/memoization.py index e92101763a..f631584f4d 100644 --- a/devito/tools/memoization.py +++ b/devito/tools/memoization.py @@ -1,9 +1,15 @@ from collections.abc import Callable, Hashable -from functools import lru_cache, partial +from functools import lru_cache, partial, wraps from itertools import tee from typing import TypeVar -__all__ = ['CacheInstances', 'memoized_func', 'memoized_generator', 'memoized_meth'] +__all__ = [ + 'CacheInstances', + 'memoized_func', + 'memoized_generator', + 'memoized_meth', + 'reuse_if_unchanged' +] class memoized_func: @@ -217,3 +223,35 @@ def clear_caches() -> None: Clears all IR instance caches. """ CacheInstancesMeta.clear_caches() + + +def reuse_if_unchanged(fields): + """ + Decorator for wrapper-style constructors that should return the original + object when called as ``Cls(existing_obj, **same_metadata)``. + + The wrapped callable is assumed to be a classmethod-like constructor + receiving ``cls`` as first argument. The fast path triggers only when: + + * the constructor is called with exactly one positional argument; + * that argument is already an exact instance of ``cls``; + * any explicitly provided metadata fields are the same objects as the + corresponding attributes on the input object. 
+ """ + def decorator(func): + @wraps(func) + def wrapper(cls, *args, **kwargs): + if len(args) == 1: + input_obj = args[0] + if type(input_obj) is cls: + names = getattr(cls, fields) if isinstance(fields, str) else fields + for name in names: + if name in kwargs and kwargs[name] is not getattr(input_obj, name, None): + break + else: + return input_obj + return func(cls, *args, **kwargs) + + return wrapper + + return decorator From bccca1f3d961ab27174a159d449ef812d6930edc Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Mon, 27 Apr 2026 11:11:40 +0100 Subject: [PATCH 18/45] compiler: Split into EqBlock and Cluster --- devito/ir/clusters/cluster.py | 296 ++++++++++++++++++++------------ devito/ir/equations/equation.py | 8 +- devito/tools/utils.py | 25 ++- 3 files changed, 214 insertions(+), 115 deletions(-) diff --git a/devito/ir/clusters/cluster.py b/devito/ir/clusters/cluster.py index 041af4e1b1..3abd154877 100644 --- a/devito/ir/clusters/cluster.py +++ b/devito/ir/clusters/cluster.py @@ -14,7 +14,7 @@ from devito.mpi.halo_scheme import HaloScheme, HaloTouch from devito.mpi.reduction_scheme import DistReduce from devito.symbolics import estimate_cost -from devito.tools import as_tuple, filter_ordered, flatten, infer_dtype +from devito.tools import CacheInstances, as_tuple, filter_ordered, flatten, infer_dtype from devito.types import ( CriticalRegion, Fence, Indexed, PhaseMarker, TensorMove, ThreadArrive, ThreadCommit, ThreadPoolSync, ThreadWait, WeakFence @@ -23,125 +23,45 @@ __all__ = ["Cluster", "ClusterGroup"] -class Cluster: +class EqBlock(CacheInstances): """ - A Cluster is an ordered sequence of expressions in an IterationSpace. - - Parameters - ---------- - exprs : expr-like or list of expr-like - An ordered sequence of expressions computing a tensor. - ispace : IterationSpace, optional - The Cluster iteration space. - guards : dict, optional - Mapper from Dimensions to expr-like, representing the conditions under - which the Cluster should be computed. 
- properties : dict, optional - Mapper from Dimensions to Property, describing the Cluster properties - such as its parallel Dimensions. - syncs : dict, optional - Mapper from Dimensions to lists of SyncOps, that is ordered sequences of - synchronization operations that must be performed in order to compute the - Cluster asynchronously. - halo_scheme : HaloScheme, optional - The halo exchanges required by the Cluster. + A sequence of equations with associated metadata. """ + @classmethod + def _preprocess_args(cls, exprs, ispace=null_ispace, guards=None, + properties=None, syncs=None, halo_scheme=None): + exprs = tuple(ClusterizedEq(e, ispace=ispace) for e in as_tuple(exprs)) + guards = Guards(guards or {}) + properties = Properties(properties or {}) + syncs = normalize_syncs(syncs or {}) + + return (exprs, ispace, guards, properties, syncs, halo_scheme), {} + def __init__(self, exprs, ispace=null_ispace, guards=None, properties=None, syncs=None, halo_scheme=None): - self._exprs = tuple(ClusterizedEq(e, ispace=ispace) for e in as_tuple(exprs)) + self._exprs = exprs self._ispace = ispace - self._guards = Guards(guards or {}) - self._syncs = normalize_syncs(syncs or {}) - - properties = Properties(properties or {}) - properties = tailor_properties(properties, ispace) - self._properties = update_properties(properties, self.exprs) - + self._guards = guards + self._syncs = syncs self._halo_scheme = halo_scheme - def __repr__(self): - return "Cluster([{}])".format(('\n' + ' '*9).join(f'{i}' for i in self.exprs)) - - @classmethod - def from_clusters(cls, *clusters): - """ - Build a new Cluster from a sequence of pre-existing Clusters with - compatible IterationSpace. 
- """ - assert len(clusters) > 0 - root = clusters[0] - - if len(clusters) == 1: - return root - - if not all(root.ispace.is_compatible(c.ispace) for c in clusters): - raise ValueError("Cannot build a Cluster from Clusters with " - "incompatible IterationSpace") - if not all(root.guards == c.guards for c in clusters): - raise ValueError("Cannot build a Cluster from Clusters with " - "non-homogeneous guards") - - writes = set().union(*[c.scope.writes for c in clusters]) - reads = set().union(*[c.scope.reads for c in clusters]) - if any(f._mem_shared for f in writes & reads): - raise ValueError("Cannot build a Cluster from Clusters with " - "read-write conflicts on shared-memory Functions") - - exprs = chain(*[c.exprs for c in clusters]) - ispace = IterationSpace.union(*[c.ispace for c in clusters]) - - guards = root.guards - - properties = reduce_properties(clusters) - - try: - syncs = normalize_syncs(*[c.syncs for c in clusters]) - except ValueError as e: - raise ValueError( - "Cannot build a Cluster from Clusters with " - "non-compatible synchronization operations" - ) from e - - halo_scheme = HaloScheme.union([c.halo_scheme for c in clusters]) - - return Cluster(exprs, ispace, guards, properties, syncs, halo_scheme) - - def rebuild(self, *args, **kwargs): - """ - Build a new Cluster from the attributes given as keywords. All other - attributes are taken from ``self``. 
- """ - # Shortcut for backwards compatibility - if args: - if len(args) != 1: - raise ValueError("rebuild takes at most one positional argument (exprs)") - if kwargs.get('exprs'): - raise ValueError("`exprs` provided both as arg and kwarg") - kwargs['exprs'] = args[0] - - exprs = kwargs.get('exprs', self.exprs) - ispace = kwargs.get('ispace', self.ispace) - guards = kwargs.get('guards', self.guards) - properties = kwargs.get('properties', self.properties) - syncs = kwargs.get('syncs', self.syncs) - halo_scheme = kwargs.get('halo_scheme', self.halo_scheme) + properties = tailor_properties(properties, ispace) + self._properties = update_properties(properties, self._exprs) - if exprs is self.exprs and \ - ispace is self.ispace and \ - guards is self.guards and \ - properties is self.properties and \ - syncs is self.syncs and \ - halo_scheme is self.halo_scheme: - return self + def __eq__(self, other): + return (type(self) is type(other) and + self.exprs == other.exprs and + self.ispace == other.ispace and + self.guards == other.guards and + self.properties == other.properties and + self.syncs == other.syncs and + self.halo_scheme == other.halo_scheme) - return self.__class__(exprs=exprs, - ispace=ispace, - guards=guards, - properties=properties, - syncs=syncs, - halo_scheme=halo_scheme) + def __hash__(self): + return hash((self.exprs, self.ispace, self.guards, self.properties, + self.syncs, self.halo_scheme)) @property def exprs(self): @@ -397,8 +317,8 @@ def dtype(self): performing integer arithmetic are ignored, assuming that they are only carrying out array index calculations. - If two expressions perform calculations with different precision, the - data type with highest precision is returned. + If two expressions perform calculations with different precision, + the data type with highest precision is returned. 
""" dtypes = set() for i in self.exprs: @@ -414,8 +334,8 @@ def dtype(self): @cached_property def dspace(self): """ - Derive the DataSpace of the Cluster from its expressions, IterationSpace, - and Guards. + Derive the DataSpace of the Cluster from its expressions, + IterationSpace, and Guards. """ accesses = detect_accesses(self.exprs) @@ -541,6 +461,156 @@ def traffic(self): return ret +class Cluster: + + """ + A context-sensitive sequence of equations. + + The structural payload (equations, IterationSpace, ...) lives in the + underlying EqBlock. A Cluster, unlike EqBlock, deliberately keeps identity + semantics because its position in a sequence of Clusters does matter. It + follows that two Cluster instances may share the same EqBlock, but they + remain distinct: Clusters intentionally use object identity for equality + and hashing, so only references to the same Cluster object compare equal. + + Parameters + ---------- + exprs : expr-like or list of expr-like + An ordered sequence of expressions computing a tensor. + ispace : IterationSpace, optional + The Cluster iteration space. + guards : dict, optional + Mapper from Dimensions to expr-like, representing the conditions under + which the Cluster should be computed. + properties : dict, optional + Mapper from Dimensions to Property, describing the Cluster properties + such as its parallel Dimensions. + syncs : dict, optional + Mapper from Dimensions to lists of SyncOps, that is ordered sequences of + synchronization operations that must be performed in order to compute the + Cluster asynchronously. + halo_scheme : HaloScheme, optional + The halo exchanges required by the Cluster. 
+ """ + + def __init__(self, exprs, ispace=null_ispace, guards=None, properties=None, + syncs=None, halo_scheme=None): + self._block = EqBlock(exprs, ispace, guards, properties, syncs, halo_scheme) + + def __repr__(self): + return "Cluster([{}])".format(('\n' + ' '*9).join(f'{i}' for i in self.exprs)) + + def __getattr__(self, name): + try: + block = object.__getattribute__(self, '_block') + except AttributeError: + raise AttributeError(name) from None + return getattr(block, name) + + @property + def exprs(self): + return self._block.exprs + + @property + def ispace(self): + return self._block.ispace + + @property + def guards(self): + return self._block.guards + + @property + def properties(self): + return self._block.properties + + @property + def syncs(self): + return self._block.syncs + + @property + def halo_scheme(self): + return self._block.halo_scheme + + @classmethod + def from_clusters(cls, *clusters): + """ + Build a new Cluster from a sequence of pre-existing Clusters with + compatible IterationSpace. 
+ """ + assert len(clusters) > 0 + root = clusters[0] + + if len(clusters) == 1: + return root + + if not all(root.ispace.is_compatible(c.ispace) for c in clusters): + raise ValueError("Cannot build a Cluster from Clusters with " + "incompatible IterationSpace") + if not all(root.guards == c.guards for c in clusters): + raise ValueError("Cannot build a Cluster from Clusters with " + "non-homogeneous guards") + + writes = set().union(*[c.scope.writes for c in clusters]) + reads = set().union(*[c.scope.reads for c in clusters]) + if any(f._mem_shared for f in writes & reads): + raise ValueError("Cannot build a Cluster from Clusters with " + "read-write conflicts on shared-memory Functions") + + exprs = chain(*[c.exprs for c in clusters]) + ispace = IterationSpace.union(*[c.ispace for c in clusters]) + + guards = root.guards + + properties = reduce_properties(clusters) + + try: + syncs = normalize_syncs(*[c.syncs for c in clusters]) + except ValueError as e: + raise ValueError( + "Cannot build a Cluster from Clusters with " + "non-compatible synchronization operations" + ) from e + + halo_scheme = HaloScheme.union([c.halo_scheme for c in clusters]) + + return Cluster(exprs, ispace, guards, properties, syncs, halo_scheme) + + def rebuild(self, *args, **kwargs): + """ + Build a new Cluster from the attributes given as keywords. All other + attributes are taken from ``self``. 
+ """ + # Shortcut for backwards compatibility + if args: + if len(args) != 1: + raise ValueError("rebuild takes at most one positional argument (exprs)") + if kwargs.get('exprs'): + raise ValueError("`exprs` provided both as arg and kwarg") + kwargs['exprs'] = args[0] + + exprs = kwargs.get('exprs', self.exprs) + ispace = kwargs.get('ispace', self.ispace) + guards = kwargs.get('guards', self.guards) + properties = kwargs.get('properties', self.properties) + syncs = kwargs.get('syncs', self.syncs) + halo_scheme = kwargs.get('halo_scheme', self.halo_scheme) + + if exprs is self.exprs and \ + ispace is self.ispace and \ + guards is self.guards and \ + properties is self.properties and \ + syncs is self.syncs and \ + halo_scheme is self.halo_scheme: + return self + + return self.__class__(exprs=exprs, + ispace=ispace, + guards=guards, + properties=properties, + syncs=syncs, + halo_scheme=halo_scheme) + + class ClusterGroup(tuple): """ diff --git a/devito/ir/equations/equation.py b/devito/ir/equations/equation.py index 283b312daa..52ac36473f 100644 --- a/devito/ir/equations/equation.py +++ b/devito/ir/equations/equation.py @@ -10,7 +10,9 @@ GuardFactor, Interval, IntervalGroup, IterationSpace, Stencil, detect_accesses ) from devito.symbolics import IntDiv, limits_mapper, retrieve_accesses, uxreplace -from devito.tools import Pickable, Tag, filter_sorted, frozendict, reuse_if_unchanged +from devito.tools import ( + Pickable, Tag, as_hashable, filter_sorted, frozendict, reuse_if_unchanged +) from devito.types import Eq, Inc, ReduceMax, ReduceMin, ReduceMinMax, relational_min __all__ = [ @@ -72,6 +74,10 @@ def state(self): def operation(self): return self._operation + def _hashable_content(self): + return (super()._hashable_content() + + tuple(as_hashable(getattr(self, i)) for i in self.__rkwargs__)) + @property def is_Reduction(self): return self.operation in (OpInc, OpMin, OpMax, OpMinMax) diff --git a/devito/tools/utils.py b/devito/tools/utils.py index 
91b5bcdbf7..470be7e79e 100644 --- a/devito/tools/utils.py +++ b/devito/tools/utils.py @@ -1,6 +1,6 @@ import types from collections import OrderedDict -from collections.abc import Iterable +from collections.abc import Iterable, Mapping from functools import reduce, wraps from itertools import chain, combinations, groupby, product, zip_longest from operator import attrgetter, mul @@ -10,6 +10,7 @@ __all__ = [ 'all_equal', + 'as_hashable', 'as_list', 'as_mapper', 'as_set', @@ -87,6 +88,28 @@ def as_tuple(item, type=None, length=None): return t +def as_hashable(item): + """ + Convert common containers into a hashable representation. + + Unknown unhashable objects fall back to identity, avoiding false cache hits. + """ + if isinstance(item, Mapping): + items = ((as_hashable(k), as_hashable(v)) for k, v in item.items()) + return tuple(sorted(items, key=repr)) + if isinstance(item, (tuple, list)): + return tuple(as_hashable(i) for i in item) + if isinstance(item, (set, frozenset)): + return tuple(sorted((as_hashable(i) for i in item), key=repr)) + + try: + hash(item) + except TypeError: + return (type(item), id(item)) + else: + return item + + def as_mapper(iterable, key=None, get=None): """ Rearrange an iterable into a dictionary of lists in which keys are From c8b007b082877f6f79f58c5a519f4b4b8d18d601 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Tue, 28 Apr 2026 09:39:34 +0100 Subject: [PATCH 19/45] compiler: Stash hash where essential for compilation performance --- devito/ir/clusters/cluster.py | 5 ++++- devito/ir/clusters/visitors.py | 3 ++- devito/ir/support/space.py | 17 +++++++++++------ devito/tools/memoization.py | 20 ++++++++++++++++++++ 4 files changed, 37 insertions(+), 8 deletions(-) diff --git a/devito/ir/clusters/cluster.py b/devito/ir/clusters/cluster.py index 3abd154877..717b508b92 100644 --- a/devito/ir/clusters/cluster.py +++ b/devito/ir/clusters/cluster.py @@ -14,7 +14,9 @@ from devito.mpi.halo_scheme import HaloScheme, HaloTouch from
devito.mpi.reduction_scheme import DistReduce from devito.symbolics import estimate_cost -from devito.tools import CacheInstances, as_tuple, filter_ordered, flatten, infer_dtype +from devito.tools import ( + CacheInstances, as_tuple, cached_hash, filter_ordered, flatten, infer_dtype +) from devito.types import ( CriticalRegion, Fence, Indexed, PhaseMarker, TensorMove, ThreadArrive, ThreadCommit, ThreadPoolSync, ThreadWait, WeakFence @@ -646,6 +648,7 @@ def __eq__(self, other): def __ne__(self, other): return not self == other + @cached_hash def __hash__(self): return hash((tuple(self), self._ispace)) diff --git a/devito/ir/clusters/visitors.py b/devito/ir/clusters/visitors.py index a5b6587a9f..da0a62344a 100644 --- a/devito/ir/clusters/visitors.py +++ b/devito/ir/clusters/visitors.py @@ -2,7 +2,7 @@ from itertools import groupby from devito.ir.support import IterationSpace, null_ispace -from devito.tools import flatten, timed_pass +from devito.tools import cached_hash, flatten, timed_pass __all__ = ['Queue', 'cluster_pass'] @@ -131,6 +131,7 @@ def __eq__(self, other): self.properties == other.properties and self.syncs == other.syncs) + @cached_hash def __hash__(self): return hash((self.intervals, self.sub_iterators, self.directions, self.guards, self.properties, self.syncs)) diff --git a/devito/ir/support/space.py b/devito/ir/support/space.py index a1ef7e50f8..7c9f970108 100644 --- a/devito/ir/support/space.py +++ b/devito/ir/support/space.py @@ -9,7 +9,7 @@ from devito.ir.support.vector import Vector, vmax, vmin from devito.tools import ( CacheInstances, Ordering, Stamp, as_list, as_set, as_tuple, filter_ordered, - flatten, frozendict, is_integer, toposort + cached_hash, flatten, frozendict, is_integer, toposort ) from devito.types import Dimension, ModuloDimension @@ -53,6 +53,7 @@ def __eq__(self, o): is_compatible = __eq__ + @cached_hash def __hash__(self): return hash(self.dim.name) @@ -103,6 +104,7 @@ def _preprocess_args(cls, dim, stamp=S0): def 
__repr__(self): return f"{self.dim}[Null]{self.stamp}" + @cached_hash def __hash__(self): return hash(self.dim) @@ -167,6 +169,7 @@ def __init__(self, dim, lower=0, upper=0, stamp=S0): def __repr__(self): return f"{self.dim}[{self.lower},{self.upper}]{self.stamp}" + @cached_hash def __hash__(self): return hash((self.dim, self.offsets)) @@ -362,6 +365,7 @@ def __eq__(self, o): def __contains__(self, d): return any(i.dim is d for i in self) + @cached_hash def __hash__(self): return hash((tuple(self), self.relations, self.mode)) @@ -620,6 +624,7 @@ def __eq__(self, other): def __repr__(self): return self._name + @cached_hash def __hash__(self): return hash(self._name) @@ -658,6 +663,7 @@ def __eq__(self, other): return False return self.direction is other.direction and super().__eq__(other) + @cached_hash def __hash__(self): return hash((self.dim, self.offsets, self.direction)) @@ -692,9 +698,6 @@ def __repr__(self): def __eq__(self, other): return isinstance(other, Space) and self.intervals == other.intervals - def __hash__(self): - return hash(self.intervals) - def __len__(self): return len(self.intervals) @@ -758,8 +761,9 @@ def __eq__(self, other): self.intervals == other.intervals and self.parts == other.parts) + @cached_hash def __hash__(self): - return hash((super().__hash__(), self.parts)) + return hash((self.intervals, self.parts)) @classmethod def union(cls, *others): @@ -854,8 +858,9 @@ def __lt__(self, other): """ return len(self.itintervals) < len(other.itintervals) + @cached_hash def __hash__(self): - return hash((super().__hash__(), self.sub_iterators, self.directions)) + return hash((self.intervals, self.sub_iterators, self.directions)) def __contains__(self, d): try: diff --git a/devito/tools/memoization.py b/devito/tools/memoization.py index f631584f4d..e003db45f2 100644 --- a/devito/tools/memoization.py +++ b/devito/tools/memoization.py @@ -5,6 +5,7 @@ __all__ = [ 'CacheInstances', + 'cached_hash', 'memoized_func', 'memoized_generator', 
'memoized_meth', @@ -12,6 +13,25 @@ ] +def cached_hash(func): + """ + Cache an immutable object's ``__hash__`` return value in ``_mhash``. + + Warning: avoid explicitly calling a superclass' cached ``__hash__`` on a + subclass instance, as that would stash the superclass hash in ``_mhash``. + """ + @wraps(func) + def wrapper(self): + try: + return self._mhash + except AttributeError: + ret = func(self) + self._mhash = ret + return ret + + return wrapper + + class memoized_func: """ Decorator. Caches a function's return value each time it is called. From 7b33fa65a97185bdbda2addf07041730fdd0c977 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Tue, 28 Apr 2026 10:45:06 +0100 Subject: [PATCH 20/45] compiler: Exploit cached_hash --- devito/ir/support/basic.py | 5 +++-- devito/tools/abc.py | 3 +++ 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/devito/ir/support/basic.py b/devito/ir/support/basic.py index 05286810fc..d9cebed3fe 100644 --- a/devito/ir/support/basic.py +++ b/devito/ir/support/basic.py @@ -14,8 +14,8 @@ retrieve_indexed ) from devito.tools import ( - CacheInstances, Tag, as_mapper, as_tuple, filter_sorted, flatten, is_integer, - memoized_func, memoized_generator, memoized_meth, smart_gt, smart_lt + CacheInstances, Tag, as_mapper, as_tuple, cached_hash, filter_sorted, flatten, + is_integer, memoized_func, memoized_generator, memoized_meth, smart_gt, smart_lt ) from devito.types import ( ComponentAccess, CriticalRegion, Dimension, DimensionTuple, Fence, Function, Symbol, @@ -253,6 +253,7 @@ def __eq__(self, other): self.access == other.access and self.ispace == other.ispace) + @cached_hash def __hash__(self): return hash((self.access, self.mode, self.timestamp, self.ispace)) diff --git a/devito/tools/abc.py b/devito/tools/abc.py index 2e489ac3c0..814eabd7f2 100644 --- a/devito/tools/abc.py +++ b/devito/tools/abc.py @@ -1,5 +1,7 @@ from hashlib import sha1 +from devito.tools.memoization import cached_hash + __all__ = ['Pickable', 
'Reconstructable', 'Signer', 'Singleton', 'Stamp', 'Tag'] @@ -34,6 +36,7 @@ def __gt__(self, other): def __ge__(self, other): return self.val >= other.val + @cached_hash def __hash__(self): return hash((self.name, self.val)) From 0baada8d378abd96d9c61192b156f150c7117d91 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Tue, 28 Apr 2026 15:35:38 +0100 Subject: [PATCH 21/45] compiler: Retain original objects whenever possible --- devito/ir/support/guards.py | 21 ++++++++++++--------- devito/ir/support/properties.py | 19 +++++++++++-------- 2 files changed, 23 insertions(+), 17 deletions(-) diff --git a/devito/ir/support/guards.py b/devito/ir/support/guards.py index b8a335b1f4..cd9a5292b2 100644 --- a/devito/ir/support/guards.py +++ b/devito/ir/support/guards.py @@ -272,31 +272,34 @@ class Guards(frozendict): def get(self, d, v=true): return super().get(d, v) + def _rebuild(self, mapper): + return self if mapper == self else Guards(mapper) + def andg(self, d, guard): m = dict(self) if guard == true: - return Guards(m) + return self try: m[d] = simplify_and(m[d], guard) except KeyError: m[d] = guard - return Guards(m) + return self._rebuild(m) def xandg(self, d, guard): m = dict(self) if guard == true: - return Guards(m) + return self try: m[d] = And(m[d], guard) except KeyError: m[d] = guard - return Guards(m) + return self._rebuild(m) def pairwise_or(self, d, *guards): m = dict(self) @@ -311,17 +314,17 @@ def pairwise_or(self, d, *guards): else: m[d] = g - return Guards(m) + return self._rebuild(m) def impose(self, d, guard): m = dict(self) if guard == true: - return Guards(m) + return self m[d] = guard - return Guards(m) + return self._rebuild(m) def popany(self, dims): m = dict(self) @@ -329,12 +332,12 @@ def popany(self, dims): for d in as_tuple(dims): m.pop(d, None) - return Guards(m) + return self._rebuild(m) def filter(self, key): m = {d: v for d, v in self.items() if key(d)} - return Guards(m) + return self._rebuild(m) def as_map(self, d, cls): if cls not in 
(Le, Lt, Ge, Gt): diff --git a/devito/ir/support/properties.py b/devito/ir/support/properties.py index 8dc759cc73..a835bb3f07 100644 --- a/devito/ir/support/properties.py +++ b/devito/ir/support/properties.py @@ -208,15 +208,18 @@ def __init__(self, *args, **kwargs): def dimensions(self): return tuple(self) + def _rebuild(self, mapper): + return self if mapper == self else Properties(mapper) + def add(self, dims, properties=None): m = dict(self) for d in as_tuple(dims): m[d] = set(self.get(d, [])) | set(as_tuple(properties)) - return Properties(m) + return self._rebuild(m) def filter(self, key): m = {d: v for d, v in self.items() if key(d)} - return Properties(m) + return self._rebuild(m) def drop(self, dims=None, properties=None): if dims is None: @@ -227,7 +230,7 @@ def drop(self, dims=None, properties=None): m.pop(d, None) else: m[d] = self[d] - set(as_tuple(properties)) - return Properties(m) + return self._rebuild(m) def parallelize(self, dims): m = dict(self) @@ -236,13 +239,13 @@ def parallelize(self, dims): v.difference_update({PARALLEL_IF_PVT, PARALLEL_IF_ATOMIC, SEQUENTIAL}) v.add(PARALLEL) m[d] = v - return Properties(m) + return self._rebuild(m) def affine(self, dims): m = dict(self) for d in as_tuple(dims): m[d] = set(self.get(d, [])) | {AFFINE} - return Properties(m) + return self._rebuild(m) def sequentialize(self, dims=None): if dims is None: @@ -250,13 +253,13 @@ def sequentialize(self, dims=None): m = dict(self) for d in as_tuple(dims): m[d] = normalize_properties(set(self.get(d, [])), {SEQUENTIAL}) - return Properties(m) + return self._rebuild(m) def prefetchable(self, dims, v=PREFETCHABLE): m = dict(self) for d in as_tuple(dims): m[d] = self.get(d, set()) | {v} - return Properties(m) + return self._rebuild(m) def block(self, dims, kind='default'): if kind == 'default': @@ -268,7 +271,7 @@ def block(self, dims, kind='default'): m = dict(self) for d in as_tuple(dims): m[d] = set(self.get(d, [])) | {p} - return Properties(m) + return 
self._rebuild(m) def inbound(self, dims): return self.add(dims, INBOUND) From e3c09da6d5650ef7d29dcc7d924e795e9f571a10 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Wed, 29 Apr 2026 08:59:13 +0100 Subject: [PATCH 22/45] compiler: Minimize reconstructions everywhere --- devito/passes/clusters/misc.py | 2 +- devito/symbolics/manipulation.py | 67 +++++++++++++++++++++++++------- tests/test_symbolics.py | 51 +++++++++++++++++++++++- 3 files changed, 103 insertions(+), 17 deletions(-) diff --git a/devito/passes/clusters/misc.py b/devito/passes/clusters/misc.py index 4530f6c28f..0cd255ce97 100644 --- a/devito/passes/clusters/misc.py +++ b/devito/passes/clusters/misc.py @@ -109,7 +109,7 @@ def optimize_pows(cluster, *args): """ Convert integer powers into Muls, such as ``a**2 => a*a``. """ - return cluster.rebuild(exprs=[pow_to_mul(e) for e in cluster.exprs]) + return cluster.rebuild(exprs=pow_to_mul(cluster.exprs)) class Fission(Queue): diff --git a/devito/symbolics/manipulation.py b/devito/symbolics/manipulation.py index 57d9314e16..85ddfcaeab 100644 --- a/devito/symbolics/manipulation.py +++ b/devito/symbolics/manipulation.py @@ -64,6 +64,9 @@ def uxreplace(expr, rule): Finally, `uxreplace` supports Reconstructable objects, that is, it searches for replacement opportunities inside the Reconstructable's `__rkwargs__`. 
""" + if not rule: + return expr + return _uxreplace(expr, rule)[0] @@ -129,13 +132,15 @@ def _(iterable, rule): ax, flag = _uxreplace(a, rule) ret.append(ax) changed |= flag - return iterable.__class__(ret), changed + return (iterable.__class__(ret), True) if changed else (iterable, False) @_uxreplace_dispatch.register(EnrichedTuple) def _(iterable, rule): retval, changed = _uxreplace_dispatch(tuple(iterable), rule) - return iterable.__class__(*retval, getters=iterable.getters), changed + if changed: + return iterable.__class__(*retval, getters=iterable.getters), True + return iterable, False @_uxreplace_dispatch.register(dict) @@ -146,7 +151,7 @@ def _(mapper, rule): vx, flag = _uxreplace_dispatch(v, rule) ret[k] = vx changed |= flag - return ret, changed + return (ret, True) if changed else (mapper, False) @singledispatch @@ -282,10 +287,18 @@ def subs_if_composite(expr, subs): Indexed"). Instead, if `subs` consists of just "primitive" expressions, then resort to the much faster `uxreplace`. """ - if all(isinstance(i, (Indexed, IndexDerivative)) for i in subs): + if not subs: + return expr + + if type(expr) is tuple: + return reuse_if_untouched(expr, (subs_if_composite(e, subs) for e in expr)) + elif type(expr) is list: + return reuse_if_untouched(expr, [subs_if_composite(e, subs) for e in expr]) + elif all(isinstance(i, (Indexed, IndexDerivative)) for i in subs): return uxreplace(expr, subs) else: - return expr.subs(subs) + processed = expr.subs(subs) + return expr if processed == expr else processed def xreplace_indices(exprs, mapper, key=None): @@ -304,14 +317,25 @@ def xreplace_indices(exprs, mapper, key=None): callable, apply the replacement to a symbol S if and only if ``key(S)`` gives True. 
""" - handle = flatten(retrieve_indexed(i) for i in as_tuple(exprs)) + exprs0 = as_tuple(exprs) + + handle = flatten(retrieve_indexed(i) for i in exprs0) if isinstance(key, Iterable): handle = [i for i in handle if i.base.label in key] elif callable(key): handle = [i for i in handle if key(i)] - mapper = dict(zip(handle, [i.xreplace(mapper) for i in handle], strict=True)) - replaced = [uxreplace(i, mapper) for i in as_tuple(exprs)] - return replaced if isinstance(exprs, Iterable) else replaced[0] + mapper = {i: v for i in handle if (v := i.xreplace(mapper)) != i} + if not mapper: + return exprs + + replaced = [uxreplace(i, mapper) for i in exprs0] + + if isinstance(exprs, Iterable): + if len(replaced) == len(exprs0) and all(i is j for i, j in zip(replaced, exprs0)): + return exprs + return replaced + else: + return replaced[0] def _eval_numbers(expr, args): @@ -344,7 +368,9 @@ def flatten_args(args, op, ignore=None): def pow_to_mul(expr): - if q_leaf(expr) or isinstance(expr, Basic): + if type(expr) in (tuple, list): + return reuse_if_untouched(expr, (pow_to_mul(i) for i in expr)) + elif q_leaf(expr) or isinstance(expr, Basic): return expr elif expr.is_Pow: base, exp = expr.as_base_exp() @@ -359,7 +385,7 @@ def pow_to_mul(expr): elif (int(exp) - exp != 0): # Fractional powers also remain untouched, # but at least we traverse the base looking for other Pows - return expr.func(pow_to_mul(base), exp, evaluate=False) + return reuse_if_untouched(expr, (pow_to_mul(base), exp), evaluate=False) elif exp > 0: return UnevalMul(*[pow_to_mul(base)]*int(exp), evaluate=False) elif exp < 0: @@ -383,7 +409,7 @@ def pow_to_mul(expr): except ValueError: pass - return expr.func(*args, evaluate=False) + return reuse_if_untouched(expr, args, evaluate=False) def indexify(expr): @@ -429,10 +455,21 @@ def normalize_args(args): def reuse_if_untouched(expr, args, evaluate=False): """ - Reconstruct `expr` iff any of the provided `args` is different than - the corresponding arg in 
`expr.args`. + Reconstruct `expr` iff any of the provided `args` is different from + the corresponding arg in `expr.args`, or from the corresponding item + for plain tuples/lists. """ - if all(a is b for a, b in zip(expr.args, args, strict=False)): + args = tuple(args) + + if type(expr) is tuple: + if len(args) == len(expr) and all(a is b for a, b in zip(expr, args)): + return expr + return args + elif type(expr) is list: + if len(args) == len(expr) and all(a is b for a, b in zip(expr, args)): + return expr + return list(args) + elif all(a is b for a, b in zip(expr.args, args, strict=False)): return expr else: return expr.func(*args, evaluate=evaluate) diff --git a/tests/test_symbolics.py b/tests/test_symbolics.py index 5e6a4150e4..ba93e08470 100644 --- a/tests/test_symbolics.py +++ b/tests/test_symbolics.py @@ -18,7 +18,7 @@ INT, BaseCast, CallFromPointer, Cast, DefFunction, FieldFromComposite, FieldFromPointer, IntDiv, ListInitializer, Namespace, ReservedWord, RoundUp, Rvalue, SizeOf, VectorAccess, evalrel, pow_to_mul, retrieve_derivatives, retrieve_functions, - retrieve_indexed, uxreplace + retrieve_indexed, subs_if_composite, uxreplace, xreplace_indices ) from devito.tools import CustomDtype, as_tuple, dtypes_vector_mapper from devito.types import ( @@ -917,6 +917,55 @@ def test_expressions(self, expr, subs, expected): assert uxreplace(eval(expr), eval(subs)) == eval(expected) + def test_uxreplace_reuses_empty_substitution(self): + grid = Grid(shape=(4, 4)) + f = Function(name='f', grid=grid) + expr = f.indexify() + 1 + + assert uxreplace(expr, {}) is expr + + def test_subs_if_composite_reuses_untouched_sequence(self): + grid = Grid(shape=(4, 4)) + x, y = grid.dimensions + f = Function(name='f', grid=grid) + g = Function(name='g', grid=grid) + + exprs = (Eq(f[x, y], f[x, y] + 1),) + + assert subs_if_composite(exprs, {}) is exprs + assert subs_if_composite(exprs, {g[x, y]: f[x, y]}) is exprs + assert subs_if_composite(exprs, {g[x, y] + 1: f[x, y]}) is exprs + + 
processed = subs_if_composite(exprs, {f[x, y]: g[x, y]}) + + assert processed is not exprs + assert processed[0] is not exprs[0] + + def test_pow_to_mul_reuses_untouched_sequence(self): + grid = Grid(shape=(4, 4)) + x, y = grid.dimensions + f = Function(name='f', grid=grid) + + exprs = (Eq(f[x, y], f[x, y] + 1),) + + assert pow_to_mul(exprs) is exprs + assert pow_to_mul([exprs[0]])[0] is exprs[0] + + processed = pow_to_mul((Eq(f[x, y], f[x, y]**2),)) + + assert processed is not exprs + + def test_xreplace_indices_reuses_untouched_sequence(self): + grid = Grid(shape=(4, 4)) + x, y = grid.dimensions + z = Dimension(name='z') + f = Function(name='f', grid=grid) + + exprs = (Eq(f[x, y], f[x, y] + 1),) + + assert xreplace_indices(exprs, {z: z + 1}) is exprs + assert xreplace_indices(exprs, {x: x + 1}) is not exprs + def test_custom_reconstructable(self): class MyDefFunction(DefFunction): From 1d97e87de6223a15bae4a209c6a4f0d65839237c Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Wed, 29 Apr 2026 10:08:50 +0100 Subject: [PATCH 23/45] compiler: Memoize Fusion._key --- devito/passes/clusters/fusion.py | 3 ++- devito/tools/memoization.py | 10 +++++++++- tests/test_tools.py | 29 ++++++++++++++++++++++++++++- 3 files changed, 39 insertions(+), 3 deletions(-) diff --git a/devito/passes/clusters/fusion.py b/devito/passes/clusters/fusion.py index 793cddedd8..c8b6ac2cb1 100644 --- a/devito/passes/clusters/fusion.py +++ b/devito/passes/clusters/fusion.py @@ -8,7 +8,7 @@ ) from devito.symbolics import search from devito.tools import ( - DAG, as_tuple, flatten, frozendict, memoized_func, timed_pass + DAG, as_tuple, flatten, frozendict, memoized_func, memoized_meth, timed_pass ) __all__ = ['fuse'] @@ -130,6 +130,7 @@ def __new__(cls, itintervals, guards, syncs, weak): return obj + @memoized_meth def _key(self, c): itintervals = frozenset(c.ispace.itintervals) guards = c.guards if any(c.guards) else None diff --git a/devito/tools/memoization.py b/devito/tools/memoization.py index 
e003db45f2..1932f930a0 100644 --- a/devito/tools/memoization.py +++ b/devito/tools/memoization.py @@ -137,11 +137,19 @@ def __call__(self, *args, **kw): cache = obj.__cache_meth except AttributeError: cache = obj.__cache_meth = {} - key = (self.func, args[1:], frozenset(kw.items())) + if kw: + key = (self.func, args[1:], frozenset(kw.items())) + else: + key = (self.func, args[1:]) + try: res = cache[key] except KeyError: res = cache[key] = self.func(*args, **kw) + except TypeError: + # Uncacheable, e.g. an unhashable item within ``args``. + return self.func(*args, **kw) + return res diff --git a/tests/test_tools.py b/tests/test_tools.py index 14ccd990b9..c448a91485 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -9,7 +9,7 @@ from devito.tools import ( DefaultFrozenDict, CacheInstances, UnboundedMultiTuple, UnboundTuple, ctypes_to_cstr, filter_ordered, - toposort, transitive_closure + memoized_meth, toposort, transitive_closure ) from devito.types.basic import Symbol @@ -62,6 +62,33 @@ def test_transitive_closure(): assert mapper == {a: d, b: d, c: d, f: e} +def test_memoized_meth(): + + class Obj: + + def __init__(self): + self.calls = 0 + + @memoized_meth + def f(self, x=None): + self.calls += 1 + return x + + obj = Obj() + + assert obj.f(1) == 1 + assert obj.f(1) == 1 + assert obj.calls == 1 + + assert obj.f(x=2) == 2 + assert obj.f(x=2) == 2 + assert obj.calls == 2 + + assert obj.f([3]) == [3] + assert obj.f([3]) == [3] + assert obj.calls == 4 + + def test_default_frozen_dict(): mapper = DefaultFrozenDict({'a': 'b'}, default='c') From 57de80720a721aaafece5f299a84748af8550020 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Wed, 29 Apr 2026 15:59:18 +0100 Subject: [PATCH 24/45] compiler: Avoid rebuilding Nodes when possible --- devito/ir/iet/nodes.py | 89 ++++++++++++++++++++++++++++++++------- devito/ir/iet/visitors.py | 22 +--------- 2 files changed, 74 insertions(+), 37 deletions(-) diff --git a/devito/ir/iet/nodes.py b/devito/ir/iet/nodes.py 
index d26b97af7a..626f2bef16 100644 --- a/devito/ir/iet/nodes.py +++ b/devito/ir/iet/nodes.py @@ -6,7 +6,7 @@ from collections import OrderedDict, namedtuple from collections.abc import Iterable from contextlib import suppress -from functools import cached_property +from functools import cached_property, lru_cache import cgen as c from sympy import IndexedBase, sympify @@ -102,27 +102,36 @@ class Node(Signer): def __new__(cls, *args, **kwargs): obj = super().__new__(cls) - argnames, _, _, defaultvalues, _, _, _ = inspect.getfullargspec(cls.__init__) - try: - defaults = dict( - zip(argnames[-len(defaultvalues):], defaultvalues, strict=True) - ) - except TypeError: - # No default kwarg values - defaults = {} - obj._args = {k: v for k, v in zip(argnames[1:], args, strict=False)} + argnames, defaults = _constructor_args(cls) + obj._args = {k: v for k, v in zip(argnames, args, strict=False)} obj._args.update(kwargs.items()) - obj._args.update({k: defaults.get(k) for k in argnames[1:] if k not in obj._args}) + obj._args.update({k: defaults.get(k) for k in argnames if k not in obj._args}) return obj def _rebuild(self, *args, **kwargs): """Reconstruct ``self``.""" handle = self._args.copy() # Original constructor arguments argnames = [i for i in self._traversable if i not in kwargs] - handle.update(OrderedDict([(k, v) for k, v in zip(argnames, args, strict=False)])) - handle.update(kwargs) + updates = OrderedDict([(k, v) for k, v in zip(argnames, args, strict=False)]) + updates.update(kwargs) + + if updates and all(self._same_arg(k, v) for k, v in updates.items()): + return self + + handle.update(updates) return type(self)(**handle) + def _same_arg(self, key, value): + with suppress(AttributeError): + if _same_as_before(getattr(self, key), value): + return True + + with suppress(KeyError): + if _same_as_before(self._args[key], value): + return True + + return False + @cached_property def ccode(self): """ @@ -1558,9 +1567,6 @@ def DummyExpr(*args, init=False): return 
Expression(DummyEq(*args), init=init) -BlankLine = CBlankLine() - - # Nodes required for distributed-memory halo exchange @@ -1635,3 +1641,54 @@ def functions(self): Iteration/Expression tree. ``local`` is a boolean indicating whether the definition of the callable is known or not. """ + + +# *** Utils + + +@lru_cache(maxsize=None) +def _constructor_args(cls): + """ + Return cached constructor argument names and default values for an IET type. + + IET node construction records the original constructor arguments in + ``_args``. This helper avoids repeating ``inspect.getfullargspec`` for every + node instance of the same class. + """ + argnames, _, _, defaultvalues, _, _, _ = inspect.getfullargspec(cls.__init__) + if defaultvalues is None: + defaults = {} + else: + defaults = dict(zip(argnames[-len(defaultvalues):], defaultvalues, strict=True)) + + return tuple(argnames[1:]), defaults + + +def _same_as_before(old, new): + """ + Return True if ``new`` preserves the object identity structure of ``old``. + + This intentionally does not use equality for arbitrary objects. It only + recurses through common containers and otherwise requires object identity, + which keeps no-op rebuild detection compatible with IET mapper semantics. 
+ """ + if old is new: + return True + + if isinstance(old, (tuple, list)) and isinstance(new, (tuple, list)): + return len(old) == len(new) and all( + _same_as_before(i, j) for i, j in zip(old, new, strict=True) + ) + + if type(old) is not type(new): + return False + + if isinstance(old, dict): + return old.keys() == new.keys() and all( + _same_as_before(v, new[k]) for k, v in old.items() + ) + + return False + + +BlankLine = CBlankLine() diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 7959ca2186..156fabc3e1 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -19,7 +19,7 @@ from devito.ir.cgen.printer import get_printer from devito.ir.iet.nodes import ( BlankLine, Call, Expression, ExpressionBundle, Iteration, Lambda, ListMajor, Node, - Section + Section, _same_as_before as same_as_before ) from devito.ir.support.space import Backward from devito.symbolics import ( @@ -1527,26 +1527,6 @@ def visit_KernelLaunch(self, o): blankline = c.Line("") -def same_as_before(old, new): - if old is new: - return True - - if isinstance(old, (tuple, list)) and isinstance(new, (tuple, list)): - return len(old) == len(new) and all( - same_as_before(i, j) for i, j in zip(old, new, strict=True) - ) - - if type(old) is not type(new): - return False - - if isinstance(old, dict): - return old.keys() == new.keys() and all( - same_as_before(v, new[k]) for k, v in old.items() - ) - - return False - - def reuse_if_unchanged(o, *children, **kwargs): def same_kwarg(k, v): with suppress(AttributeError): From 3b9640b7d5cfe629e35e75c5a16920b862af6a75 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Wed, 29 Apr 2026 15:59:29 +0100 Subject: [PATCH 25/45] compiler: Memoize IET visitors --- devito/ir/iet/visitors.py | 12 ++++++++- devito/tools/memoization.py | 50 +++++++++++++++++++++++++++++++++++++ tests/test_tools.py | 38 +++++++++++++++++++++++++++- 3 files changed, 98 insertions(+), 2 deletions(-) diff --git a/devito/ir/iet/visitors.py 
b/devito/ir/iet/visitors.py index 156fabc3e1..99b0768646 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -28,7 +28,7 @@ from devito.symbolics.extended_dtypes import NoDeclStruct from devito.tools import ( GenericVisitor, as_tuple, c_restrict_void_p, filter_ordered, filter_sorted, flatten, - is_external_ctype, sorted_priority + is_external_ctype, memoized_weak_meth, sorted_priority ) from devito.types import ( ArrayObject, CompositeObject, DeviceMap, Dimension, IndexedData, Pointer @@ -1115,12 +1115,17 @@ def _defines_aliases(n): def __init__(self, mode: str = 'symbolics') -> None: super().__init__() + self.mode = mode modes = mode.split('|') if len(modes) == 1: self.rule = self.rules[mode] else: self.rule = lambda n: chain(*[self.rules[mode](n) for mode in modes]) + @memoized_weak_meth(key=lambda i: i.mode, freeze=tuple, thaw=list) + def visit(self, o, *args, **kwargs): + return super().visit(o, *args, **kwargs) + def _post_visit(self, ret): return sorted(filter_ordered(ret, key=id), key=str) @@ -1236,8 +1241,13 @@ class FindApplications(LazyVisitor[ApplicationType, set[ApplicationType], None]) def __init__(self, cls: type[ApplicationType] = Application): super().__init__() + self.cls = cls self.match = lambda i: isinstance(i, cls) and not isinstance(i, Basic) + @memoized_weak_meth(key=lambda i: i.cls, freeze=frozenset, thaw=set) + def visit(self, o, *args, **kwargs): + return super().visit(o, *args, **kwargs) + def _post_visit(self, ret): return set(ret) diff --git a/devito/tools/memoization.py b/devito/tools/memoization.py index 1932f930a0..cb36870e37 100644 --- a/devito/tools/memoization.py +++ b/devito/tools/memoization.py @@ -2,6 +2,7 @@ from functools import lru_cache, partial, wraps from itertools import tee from typing import TypeVar +from weakref import WeakKeyDictionary __all__ = [ 'CacheInstances', @@ -9,6 +10,7 @@ 'memoized_func', 'memoized_generator', 'memoized_meth', + 'memoized_weak_meth', 'reuse_if_unchanged' ] @@ -187,6 
+189,54 @@ def __call__(self, *args, **kwargs): return result +def memoized_weak_meth(*, key=None, freeze=None, thaw=None): + """ + Cache a method result against its first argument using weak references. + + This is useful for visitors operating on temporary IR roots: the cache can + be shared across short-lived visitor instances without keeping those roots + alive. Only calls without extra arguments are cached; all other calls fall + back to the wrapped method. + + Parameters + ---------- + key : callable, optional + A callable receiving ``self`` and returning a hashable cache partition. + freeze : callable, optional + Convert the method result before storing it in the cache. + thaw : callable, optional + Convert the cached value before returning it to the caller. + """ + def decorator(func): + caches = {} + + @wraps(func) + def wrapper(self, o, *args, **kwargs): + if args or kwargs: + return func(self, o, *args, **kwargs) + + try: + partition = key(self) if key is not None else None + cache = caches.setdefault(partition, WeakKeyDictionary()) + ret = cache[o] + except KeyError: + ret = func(self, o) + if freeze is not None: + ret = freeze(ret) + cache[o] = ret + except TypeError: + return func(self, o) + + if thaw is not None: + return thaw(ret) + + return ret + + return wrapper + + return decorator + + # Describes the type of a subclass of CacheInstances InstanceType = TypeVar('InstanceType', bound='CacheInstances', covariant=True) diff --git a/tests/test_tools.py b/tests/test_tools.py index c448a91485..e1943fe802 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -9,7 +9,7 @@ from devito.tools import ( DefaultFrozenDict, CacheInstances, UnboundedMultiTuple, UnboundTuple, ctypes_to_cstr, filter_ordered, - memoized_meth, toposort, transitive_closure + memoized_meth, memoized_weak_meth, toposort, transitive_closure ) from devito.types.basic import Symbol @@ -89,6 +89,42 @@ def f(self, x=None): assert obj.calls == 4 +def test_memoized_weak_meth(): + + 
class Root: + pass + + class Obj: + + def __init__(self, mode): + self.mode = mode + self.calls = 0 + + @memoized_weak_meth(key=lambda i: i.mode, freeze=tuple, thaw=list) + def f(self, root): + self.calls += 1 + return [self.mode] + + root = Root() + obj0 = Obj('a') + obj1 = Obj('a') + obj2 = Obj('b') + + ret = obj0.f(root) + ret.append('mutated') + + assert obj1.f(root) == ['a'] + assert obj0.calls == 1 + assert obj1.calls == 0 + + assert obj2.f(root) == ['b'] + assert obj2.calls == 1 + + assert obj0.f([]) == ['a'] + assert obj0.f([]) == ['a'] + assert obj0.calls == 3 + + def test_default_frozen_dict(): mapper = DefaultFrozenDict({'a': 'b'}, default='c') From 307cdaf32c553ed25d0fcf763cc3ecb255aac75d Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Thu, 30 Apr 2026 11:22:18 +0100 Subject: [PATCH 26/45] compiler: Memoize IET engine --- devito/ir/iet/visitors.py | 13 +------------ devito/passes/iet/engine.py | 24 ++++++++++++++++-------- 2 files changed, 17 insertions(+), 20 deletions(-) diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 99b0768646..9e5913efd6 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -1538,17 +1538,6 @@ def visit_KernelLaunch(self, o): def reuse_if_unchanged(o, *children, **kwargs): - def same_kwarg(k, v): - with suppress(AttributeError): - if same_as_before(getattr(o, k), v): - return True - - with suppress(KeyError): - if same_as_before(o.args[k], v): - return True - - return False - if children: same_children = all( same_as_before(i, j) for i, j in zip(o.children, children, strict=True) @@ -1556,7 +1545,7 @@ def same_kwarg(k, v): if not same_children: return o._rebuild(*children, **kwargs) - if kwargs and not all(same_kwarg(k, v) for k, v in kwargs.items()): + if kwargs and not all(o._same_arg(k, v) for k, v in kwargs.items()): return o._rebuild(*children, **kwargs) return o diff --git a/devito/passes/iet/engine.py b/devito/passes/iet/engine.py index 0e92de1ec2..85217d6ed6 100644 --- 
a/devito/passes/iet/engine.py +++ b/devito/passes/iet/engine.py @@ -15,7 +15,10 @@ from devito.mpi.routines import Gather, HaloUpdate, HaloWait, MPIMsg, Scatter from devito.passes import needs_transfer from devito.symbolics import FieldFromComposite, FieldFromPointer, IndexedPointer, search -from devito.tools import DAG, as_tuple, filter_ordered, sorted_priority, timed_pass +from devito.tools import ( + DAG, as_hashable, as_tuple, filter_ordered, memoized_func, sorted_priority, + timed_pass +) from devito.types import ( Array, Auto, Bundle, ComponentAccess, CompositeObject, FunctionMap, IncrDimension, Indirection, ModuloDimension, NPThreads, NThreadsBase, Pointer, SharedData, Symbol, @@ -102,7 +105,7 @@ def sync_mapper(self): A mapper {Iteration -> SyncSpot} describing the Iterations, if any, living an asynchronous region, across all Callables in the Graph. """ - dag = create_call_graph(self.root.name, self.efuncs) + dag = create_call_graph(self.root.name, as_hashable(self.efuncs)) mapper = MapNodes(SyncSpot, (Iteration, Call)).visit(self.root) @@ -129,7 +132,7 @@ def apply(self, func, *, updates_args=True, **kwargs): """ Apply `func` to all nodes in the Graph. This changes the state of the Graph. """ - dag = create_call_graph(self.root.name, self.efuncs) + dag = create_call_graph(self.root.name, as_hashable(self.efuncs)) # Apply `func` efuncs = dict(self.efuncs) @@ -184,7 +187,7 @@ def visit(self, func, **kwargs): from nodes to info. Unlike `apply`, `visit` does not change the state of the Graph. """ - dag = create_call_graph(self.root.name, self.efuncs) + dag = create_call_graph(self.root.name, as_hashable(self.efuncs)) toposort = dag.topological_sort() mapper = dict([(i, func(self.efuncs[i], **kwargs)) for i in toposort]) @@ -242,11 +245,14 @@ def iet_visit(func): return iet_pass((iet_visit, func)) +@memoized_func(scope='build') def create_call_graph(root, efuncs): """ Create a Call graph -- a Direct Acyclic Graph with edges from callees to callers. 
""" + efuncs = dict(efuncs) + dag = DAG(nodes=[root]) queue = [root] @@ -442,7 +448,7 @@ def reuse_efuncs(root, efuncs, sregistry=None): # assuming that `bar0` and `bar1` are compatible, we first process the # `bar`'s to obtain `[foo0(u(x)): bar0(u), foo1(u(x)): bar0(u)]`, # and finally `foo0(u(x)): bar0(u)` - dag = create_call_graph(root.name, efuncs) + dag = create_call_graph(root.name, as_hashable(efuncs)) mapper = {} for i in dag.topological_sort(): @@ -484,6 +490,7 @@ def reuse_efuncs(root, efuncs, sregistry=None): return retval +@memoized_func(scope='build') def abstract_efunc(efunc): """ Abstract `efunc` applying a set of rules: @@ -496,7 +503,7 @@ def abstract_efunc(efunc): """ functions = FindSymbols('basics|symbolics|dimensions').visit(efunc) - mapper = abstract_objects(functions) + mapper = abstract_objects(tuple(functions)) efunc = Uxreplace(mapper).visit(efunc) efunc = efunc._rebuild(name='foo') @@ -504,7 +511,8 @@ def abstract_efunc(efunc): return efunc -def abstract_objects(objects0, sregistry=None): +@memoized_func(scope='build') +def abstract_objects(objects0): """ Proxy for `abstract_object`. 
""" @@ -523,7 +531,7 @@ def abstract_objects(objects0, sregistry=None): # Build abstraction mappings mapper = {} - sregistry = sregistry or SymbolRegistry() + sregistry = SymbolRegistry() for i in objects: abstract_object(i, mapper, sregistry) From c885f3f041d383203d5a596c2ff4c4964f80421d Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Thu, 30 Apr 2026 14:09:52 +0100 Subject: [PATCH 27/45] compiler: Avoid reconstructions in IET visitors --- devito/ir/iet/visitors.py | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 9e5913efd6..3fd340c953 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -1331,6 +1331,13 @@ def __init__(self, mapper, nested=False): self.mapper = mapper self.nested = nested + def visit(self, o, *args, **kwargs): + # Subclasses may implement mapper-independent transformations. + if type(self) is Transformer and not self.mapper: + return o + + return super().visit(o, *args, **kwargs) + def transform(self, o, handle, **kwargs): if handle is None: # None -> drop `o` @@ -1391,6 +1398,12 @@ class Uxreplace(Transformer): The substitution rules. 
""" + def visit(self, o, *args, **kwargs): + if not self.mapper: + return o + + return super().visit(o, *args, **kwargs) + def visit_Expression(self, o): return reuse_if_unchanged(o, expr=uxreplace(o.expr, self.mapper)) From e287b7eccea0dd63b506a2ca2502a09daa8abf32 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Thu, 30 Apr 2026 14:46:33 +0100 Subject: [PATCH 28/45] compiler: Memoize FindNodes --- devito/ir/iet/visitors.py | 10 ++++++++++ tests/test_visitors.py | 13 +++++++++++-- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 3fd340c953..082ef85898 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -1172,8 +1172,13 @@ class FindNodes(LazyVisitor[Node, list[Node], None]): def __init__(self, match: type, mode: str = 'type') -> None: super().__init__() self.match = match + self.mode = mode self.rule = self.rules[mode] + @memoized_weak_meth(key=lambda i: (i.match, i.mode), freeze=tuple, thaw=list) + def visit(self, o, *args, **kwargs): + return super().visit(o, *args, **kwargs) + def visit_Node(self, o: Node, **kwargs) -> Iterator[Node]: if self.rule(self.match, o): yield o @@ -1194,6 +1199,11 @@ def __init__(self, match: type, start: Node, stop: Node | None = None) -> None: self.start = start self.stop = stop + def visit(self, o, *args, **kwargs): + # `start` and `stop` are part of this visitor's state. 
+ return GenericVisitor.visit(self, o, *args, **kwargs) + + def visit_object(self, o: object, flag: bool = False) -> LazyVisit[Node, bool]: yield from () return flag # noqa: B901 diff --git a/tests/test_visitors.py b/tests/test_visitors.py index 7acd7bbf86..b5d12f81d5 100644 --- a/tests/test_visitors.py +++ b/tests/test_visitors.py @@ -6,8 +6,8 @@ from devito.ir.equations import DummyEq from devito.ir.iet import ( Block, Call, Callable, Conditional, Expression, FindApplications, FindNodes, - FindSections, FindSymbols, IsPerfectIteration, Iteration, MapNodes, Transformer, - Uxreplace, printAST + FindSections, FindSymbols, FindWithin, IsPerfectIteration, Iteration, MapNodes, + Transformer, Uxreplace, printAST ) from devito.types import Array, SpaceDimension, Symbol @@ -210,6 +210,15 @@ def test_find_sections(exprs, block1, block2, block3): assert len(found[2]) == 1 +def test_find_within_not_cached_like_findnodes(block3): + expr0 = FindWithin(Expression, block3.nodes[0], block3.nodes[1]).visit(block3) + expr1 = FindWithin(Expression, block3.nodes[1], block3.nodes[2]).visit(block3) + + assert len(expr0) == 3 + assert len(expr1) == 3 + assert expr0 != expr1 + + def test_is_perfect_iteration(block1, block2, block3, block4): checker = IsPerfectIteration() From 388a587a2e75f1a546ee9eba58f6296651212d05 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Tue, 12 May 2026 11:19:17 +0100 Subject: [PATCH 29/45] compiler: Call finalize_args once at the end of the lowering --- devito/operator/operator.py | 7 ++-- devito/passes/iet/definitions.py | 6 ++-- devito/passes/iet/engine.py | 58 +++++++++++++++++++++++++----- devito/passes/iet/instrument.py | 4 +-- devito/passes/iet/linearization.py | 2 +- devito/passes/iet/misc.py | 6 ++-- tests/test_iet.py | 8 ++--- 7 files changed, 67 insertions(+), 24 deletions(-) diff --git a/devito/operator/operator.py b/devito/operator/operator.py index 8a7ceaf259..92997653da 100644 --- a/devito/operator/operator.py +++ 
b/devito/operator/operator.py @@ -32,8 +32,8 @@ from devito.operator.registry import operator_selector from devito.parameters import configuration from devito.passes import ( - Graph, error_mapper, generate_implicit, generate_macros, is_on_device, lower_dtypes, - lower_index_derivatives, minimize_symbols, optimize_pows, unevaluate + Graph, error_mapper, finalize_args, generate_implicit, generate_macros, is_on_device, + lower_dtypes, lower_index_derivatives, minimize_symbols, optimize_pows, unevaluate ) from devito.symbolics import estimate_cost, subs_op_args from devito.tools import ( @@ -522,6 +522,9 @@ def _lower_iet(cls, uiet, **kwargs): # Target-independent optimizations minimize_symbols(graph) + # Finalize helper signatures after all IET transformations have settled. + finalize_args(graph) + return graph.root, graph # Read-only properties exposed to the outside world diff --git a/devito/passes/iet/definitions.py b/devito/passes/iet/definitions.py index bef8f0f1f4..9fc20bfcee 100644 --- a/devito/passes/iet/definitions.py +++ b/devito/passes/iet/definitions.py @@ -444,7 +444,7 @@ def _inject_definitions(self, iet, storage): return processed, flatten(efuncs) - @iet_pass + @iet_pass(updates_args=True) def place_definitions(self, iet, globs=None, **kwargs): """ Create a new IET where all symbols have been declared, allocated, and @@ -518,7 +518,7 @@ def place_definitions(self, iet, globs=None, **kwargs): 'globals': as_tuple(globs), 'includes': as_tuple(sorted(storage.includes))} - @iet_pass(updates_args=False) + @iet_pass def place_casts(self, iet, **kwargs): """ Create a new IET with the necessary type casts. @@ -669,7 +669,7 @@ def place_transfers(self, iet, data_movs=None, ctx=None, **kwargs): return iet, {'efuncs': efuncs} - @iet_pass(updates_args=False) + @iet_pass def place_devptr(self, iet, **kwargs): """ Transform `iet` such that device pointers are used in DeviceCalls. 
diff --git a/devito/passes/iet/engine.py b/devito/passes/iet/engine.py index 85217d6ed6..3655e4a413 100644 --- a/devito/passes/iet/engine.py +++ b/devito/passes/iet/engine.py @@ -28,7 +28,7 @@ from devito.types.dense import DiscreteFunction from devito.types.dimension import AbstractIncrDimension, BlockDimension -__all__ = ['Graph', 'iet_pass', 'iet_visit'] +__all__ = ['Graph', 'finalize_args', 'iet_pass', 'iet_visit'] class Byproduct: @@ -128,10 +128,22 @@ def sync_mapper(self): return found - def apply(self, func, *, updates_args=True, **kwargs): + def apply(self, func, *, updates_args=False, **kwargs): """ - Apply `func` to all nodes in the Graph. This changes the state of the Graph. + Apply ``func`` to all nodes in the Graph. + + Parameters + ---------- + updates_args : bool, optional + If True, reconcile Callable parameters and Call arguments before + the graph walk and after each changed node. This is only needed by + passes whose transformation logic depends on already-updated + signatures while the pass is still running. Otherwise, argument + reconciliation is intentionally deferred to ``finalize_args``. """ + if updates_args: + _update_args(self) + dag = create_call_graph(self.root.name, as_hashable(self.efuncs)) # Apply `func` @@ -161,10 +173,8 @@ def apply(self, func, *, updates_args=True, **kwargs): efuncs[i] = efunc efuncs.update(dict([(i.name, i) for i in new_efuncs])) - # Update the parameters / arguments lists if the pass may have - # introduced or removed objects. if updates_args: - efuncs = update_args(efunc, efuncs, dag) + efuncs = _update_args_efunc(efunc, efuncs, dag) # Minimize code size if len(efuncs) > len(self.efuncs): @@ -209,7 +219,37 @@ def filter(self, key): ) -def iet_pass(func=None, *, updates_args=True): +@timed_pass(name='finalize_args') +def finalize_args(graph): + """ + Finalize Callable parameter lists and Call argument lists across ``graph``. 
+ + IET passes may temporarily leave helper signatures stale while introducing + or eliminating symbols. This pass reconciles the whole call graph once, + after lowering has settled. + """ + _update_args(graph) + + +def _update_args(graph): + dag = create_call_graph(graph.root.name, as_hashable(graph.efuncs)) + + efuncs = graph.efuncs + for i in dag.topological_sort(): + efuncs = _update_args_efunc(efuncs[i], efuncs, dag) + + graph.efuncs = efuncs + + +def iet_pass(func=None, *, updates_args=False): + """ + Decorate an IET pass. + + ``updates_args=True`` is an opt-in for passes that must observe up-to-date + Callable/Call signatures before and during their own graph walk. Most + passes should leave it False and rely on ``finalize_args`` at the end of + IET lowering. + """ if func is None: return partial(iet_pass, updates_args=updates_args) @@ -702,7 +742,7 @@ def _(i, mapper, sregistry): mapper[i] = i._rebuild(name=sregistry.make_name(prefix='nthreads')) -def update_args(root, efuncs, dag): +def _update_args_efunc(root, efuncs, dag): """ Re-derive the parameters of `root` and apply the changes in cascade through the `efuncs`. @@ -800,6 +840,6 @@ def _filter(v, efunc=None): continue efuncs[n] = efunc - efuncs = update_args(efunc, efuncs, dag) + efuncs = _update_args_efunc(efunc, efuncs, dag) return efuncs diff --git a/devito/passes/iet/instrument.py b/devito/passes/iet/instrument.py index 7251683f0f..9bd1ce2134 100644 --- a/devito/passes/iet/instrument.py +++ b/devito/passes/iet/instrument.py @@ -27,7 +27,7 @@ def instrument(graph, **kwargs): sync_sections(graph, **kwargs) -@iet_pass(updates_args=False) +@iet_pass def track_subsections(iet, **kwargs): """ Add sub-Sections to the `profiler`. 
Sub-Sections include: @@ -122,7 +122,7 @@ def instrument_sections(iet, **kwargs): return piet, {'headers': headers} -@iet_pass(updates_args=False) +@iet_pass def sync_sections(iet, langbb=None, profiler=None, **kwargs): """ Wrap sections within global barriers if deemed necessary by the profiler. diff --git a/devito/passes/iet/linearization.py b/devito/passes/iet/linearization.py index aca2485444..78ec84987c 100644 --- a/devito/passes/iet/linearization.py +++ b/devito/passes/iet/linearization.py @@ -46,7 +46,7 @@ def linearize(graph, **kwargs): linearization(graph, key=key, tracker=tracker, **kwargs) -@iet_pass +@iet_pass(updates_args=True) def linearization(iet, key=None, tracker=None, **kwargs): """ Carry out the actual work of `linearize`. diff --git a/devito/passes/iet/misc.py b/devito/passes/iet/misc.py index 1d2324e3d1..0a631bf3a2 100644 --- a/devito/passes/iet/misc.py +++ b/devito/passes/iet/misc.py @@ -29,7 +29,7 @@ ] -@iet_pass(updates_args=False) +@iet_pass def avoid_denormals(iet, platform=None, **kwargs): """ Introduce nodes in the Iteration/Expression tree that will expand to C @@ -60,7 +60,7 @@ def avoid_denormals(iet, platform=None, **kwargs): return iet, {'includes': ('xmmintrin.h', 'pmmintrin.h')} -@iet_pass(updates_args=False) +@iet_pass def hoist_prodders(iet): """ Move Prodders within the outer levels of an Iteration tree. 
@@ -151,7 +151,7 @@ def generate_macros(graph, **kwargs): _generate_macros(graph, tracker={}, **kwargs) -@iet_pass(updates_args=False) +@iet_pass def _generate_macros(iet, tracker=None, langbb=None, printer=CPrinter, **kwargs): # Derive the Macros necessary for the FIndexeds iet = _generate_macros_findexeds(iet, tracker=tracker, **kwargs) diff --git a/tests/test_iet.py b/tests/test_iet.py index 129c895a3e..0c3f944a32 100644 --- a/tests/test_iet.py +++ b/tests/test_iet.py @@ -540,20 +540,20 @@ def test_complex_array(): "float _Complex **restrict a_vec __attribute__ ((aligned (64)));" -def test_iet_pass_skip_update_args(monkeypatch): +def test_iet_pass_does_not_update_args(monkeypatch): x = Symbol(name='x') y = Symbol(name='y') foo = Callable('foo', DummyExpr(x, y), 'void', parameters=(x, y)) graph = Graph(foo) - @iet_pass(updates_args=False) + @iet_pass def inject_expr(iet): body = iet.body._rebuild(body=iet.body.body + (DummyExpr(x, x),)) return iet._rebuild(body=body), {} - monkeypatch.setattr(iet_engine, 'update_args', - lambda *args, **kwargs: pytest.fail("update_args called")) + monkeypatch.setattr(iet_engine, '_update_args', + lambda *args, **kwargs: pytest.fail("_update_args called")) inject_expr(graph) From 21b93023c0a54948af01f82ed10e6704f32d20ae Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Tue, 16 Jun 2026 11:51:35 +0100 Subject: [PATCH 30/45] compiler: Rename _rebuild -> _reuse_if_untouched --- devito/ir/support/guards.py | 14 +++++++------- devito/ir/support/properties.py | 18 +++++++++--------- 2 files changed, 16 insertions(+), 16 deletions(-) diff --git a/devito/ir/support/guards.py b/devito/ir/support/guards.py index cd9a5292b2..697ddb0f8e 100644 --- a/devito/ir/support/guards.py +++ b/devito/ir/support/guards.py @@ -272,7 +272,7 @@ class Guards(frozendict): def get(self, d, v=true): return super().get(d, v) - def _rebuild(self, mapper): + def _reuse_if_untouched(self, mapper): return self if mapper == self else Guards(mapper) def andg(self, 
d, guard): @@ -286,7 +286,7 @@ def andg(self, d, guard): except KeyError: m[d] = guard - return self._rebuild(m) + return self._reuse_if_untouched(m) def xandg(self, d, guard): m = dict(self) @@ -299,7 +299,7 @@ def xandg(self, d, guard): except KeyError: m[d] = guard - return self._rebuild(m) + return self._reuse_if_untouched(m) def pairwise_or(self, d, *guards): m = dict(self) @@ -314,7 +314,7 @@ def pairwise_or(self, d, *guards): else: m[d] = g - return self._rebuild(m) + return self._reuse_if_untouched(m) def impose(self, d, guard): m = dict(self) @@ -324,7 +324,7 @@ def impose(self, d, guard): m[d] = guard - return self._rebuild(m) + return self._reuse_if_untouched(m) def popany(self, dims): m = dict(self) @@ -332,12 +332,12 @@ def popany(self, dims): for d in as_tuple(dims): m.pop(d, None) - return self._rebuild(m) + return self._reuse_if_untouched(m) def filter(self, key): m = {d: v for d, v in self.items() if key(d)} - return self._rebuild(m) + return self._reuse_if_untouched(m) def as_map(self, d, cls): if cls not in (Le, Lt, Ge, Gt): diff --git a/devito/ir/support/properties.py b/devito/ir/support/properties.py index a835bb3f07..6827e7c7bf 100644 --- a/devito/ir/support/properties.py +++ b/devito/ir/support/properties.py @@ -208,18 +208,18 @@ def __init__(self, *args, **kwargs): def dimensions(self): return tuple(self) - def _rebuild(self, mapper): + def _reuse_if_untouched(self, mapper): return self if mapper == self else Properties(mapper) def add(self, dims, properties=None): m = dict(self) for d in as_tuple(dims): m[d] = set(self.get(d, [])) | set(as_tuple(properties)) - return self._rebuild(m) + return self._reuse_if_untouched(m) def filter(self, key): m = {d: v for d, v in self.items() if key(d)} - return self._rebuild(m) + return self._reuse_if_untouched(m) def drop(self, dims=None, properties=None): if dims is None: @@ -230,7 +230,7 @@ def drop(self, dims=None, properties=None): m.pop(d, None) else: m[d] = self[d] - set(as_tuple(properties)) - 
return self._rebuild(m) + return self._reuse_if_untouched(m) def parallelize(self, dims): m = dict(self) @@ -239,13 +239,13 @@ def parallelize(self, dims): v.difference_update({PARALLEL_IF_PVT, PARALLEL_IF_ATOMIC, SEQUENTIAL}) v.add(PARALLEL) m[d] = v - return self._rebuild(m) + return self._reuse_if_untouched(m) def affine(self, dims): m = dict(self) for d in as_tuple(dims): m[d] = set(self.get(d, [])) | {AFFINE} - return self._rebuild(m) + return self._reuse_if_untouched(m) def sequentialize(self, dims=None): if dims is None: @@ -253,13 +253,13 @@ def sequentialize(self, dims=None): m = dict(self) for d in as_tuple(dims): m[d] = normalize_properties(set(self.get(d, [])), {SEQUENTIAL}) - return self._rebuild(m) + return self._reuse_if_untouched(m) def prefetchable(self, dims, v=PREFETCHABLE): m = dict(self) for d in as_tuple(dims): m[d] = self.get(d, set()) | {v} - return self._rebuild(m) + return self._reuse_if_untouched(m) def block(self, dims, kind='default'): if kind == 'default': @@ -271,7 +271,7 @@ def block(self, dims, kind='default'): m = dict(self) for d in as_tuple(dims): m[d] = set(self.get(d, [])) | {p} - return self._rebuild(m) + return self._reuse_if_untouched(m) def inbound(self, dims): return self.add(dims, INBOUND) From 644b24e96746d976a1a6133980359d911287c691 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Tue, 16 Jun 2026 11:54:27 +0100 Subject: [PATCH 31/45] compiler: Avoid useless renaming --- devito/ir/iet/visitors.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 082ef85898..2b674426bf 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -19,7 +19,7 @@ from devito.ir.cgen.printer import get_printer from devito.ir.iet.nodes import ( BlankLine, Call, Expression, ExpressionBundle, Iteration, Lambda, ListMajor, Node, - Section, _same_as_before as same_as_before + Section, _same_as_before ) from devito.ir.support.space import Backward 
from devito.symbolics import ( @@ -1377,7 +1377,7 @@ def visit_tuple(self, o, **kwargs): visited = tuple(self._visit(i, **kwargs) for i in o) processed = tuple(i for i in visited if i is not None) - if same_as_before(o, processed): + if _same_as_before(o, processed): return o return processed @@ -1563,7 +1563,7 @@ def visit_KernelLaunch(self, o): def reuse_if_unchanged(o, *children, **kwargs): if children: same_children = all( - same_as_before(i, j) for i, j in zip(o.children, children, strict=True) + _same_as_before(i, j) for i, j in zip(o.children, children, strict=True) ) if not same_children: return o._rebuild(*children, **kwargs) From 76ebc86aec7f37f04762473d33e279f24cdf709b Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Tue, 16 Jun 2026 11:59:28 +0100 Subject: [PATCH 32/45] compiler: Refactor IREq --- devito/ir/equations/equation.py | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/devito/ir/equations/equation.py b/devito/ir/equations/equation.py index 52ac36473f..f74fb65c84 100644 --- a/devito/ir/equations/equation.py +++ b/devito/ir/equations/equation.py @@ -87,7 +87,7 @@ def is_Increment(self): return self.operation is OpInc @cached_property - def _writes(self): + def writes(self): from devito.symbolics.queries import q_routine terminals = set(retrieve_accesses(self.lhs)) @@ -98,10 +98,6 @@ def _writes(self): return tuple(terminals) - @property - def writes(self): - return self._writes - @cached_property def reads_explicit(self): terminals = set(retrieve_accesses(self.rhs, deep=True)) @@ -119,12 +115,8 @@ def reads_conditional(self): return tuple(accesses) @cached_property - def _reads(self): - return tuple(set(self.reads_explicit) | set(self.reads_conditional)) - - @property def reads(self): - return self._reads + return tuple(set(self.reads_explicit) | set(self.reads_conditional)) @cached_property def _read_functions(self): From eb7014d8410a607a3331bc804c1ee5b53c45712b Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: 
Tue, 16 Jun 2026 12:10:10 +0100 Subject: [PATCH 33/45] compiler: isort happiness --- devito/ir/support/basic.py | 3 +-- devito/ir/support/space.py | 4 ++-- devito/operator/operator.py | 4 ++-- devito/passes/clusters/misc.py | 4 +--- devito/passes/iet/engine.py | 3 +-- tests/test_ir.py | 2 +- tests/test_tools.py | 6 +++--- 7 files changed, 11 insertions(+), 15 deletions(-) diff --git a/devito/ir/support/basic.py b/devito/ir/support/basic.py index d9cebed3fe..ba15b3dcf0 100644 --- a/devito/ir/support/basic.py +++ b/devito/ir/support/basic.py @@ -10,8 +10,7 @@ from devito.ir.support.utils import AccessMode, extrema from devito.ir.support.vector import LabeledVector, Vector from devito.symbolics import ( - compare_ops, q_affine, q_comp_acc, q_constant, retrieve_accesses, - retrieve_indexed + compare_ops, q_affine, q_comp_acc, q_constant, retrieve_accesses, retrieve_indexed ) from devito.tools import ( CacheInstances, Tag, as_mapper, as_tuple, cached_hash, filter_sorted, flatten, diff --git a/devito/ir/support/space.py b/devito/ir/support/space.py index 7c9f970108..43b85516fb 100644 --- a/devito/ir/support/space.py +++ b/devito/ir/support/space.py @@ -8,8 +8,8 @@ from devito.ir.support.utils import maximum, minimum from devito.ir.support.vector import Vector, vmax, vmin from devito.tools import ( - CacheInstances, Ordering, Stamp, as_list, as_set, as_tuple, filter_ordered, - cached_hash, flatten, frozendict, is_integer, toposort + CacheInstances, Ordering, Stamp, as_list, as_set, as_tuple, cached_hash, + filter_ordered, flatten, frozendict, is_integer, toposort ) from devito.types import Dimension, ModuloDimension diff --git a/devito/operator/operator.py b/devito/operator/operator.py index 92997653da..a57ce5bd04 100644 --- a/devito/operator/operator.py +++ b/devito/operator/operator.py @@ -38,8 +38,8 @@ from devito.symbolics import estimate_cost, subs_op_args from devito.tools import ( DAG, CacheInstances, MemoryEstimate, OrderedSet, ReducerMap, Signer, as_mapper, - 
as_tuple, contains_val, filter_sorted, flatten, frozendict, is_integer, - memoized_func, split, timed_pass, timed_region + as_tuple, contains_val, filter_sorted, flatten, frozendict, is_integer, memoized_func, + split, timed_pass, timed_region ) from devito.types import Buffer, Evaluable, device_layer, disk_layer, host_layer from devito.types.dimension import Thickness diff --git a/devito/passes/clusters/misc.py b/devito/passes/clusters/misc.py index 0cd255ce97..68b982eedc 100644 --- a/devito/passes/clusters/misc.py +++ b/devito/passes/clusters/misc.py @@ -1,9 +1,7 @@ from itertools import groupby, product from devito.ir.clusters import Queue, cluster_pass -from devito.ir.support import ( - SEPARABLE, SEQUENTIAL, Scope -) +from devito.ir.support import SEPARABLE, SEQUENTIAL, Scope from devito.passes.clusters.utils import in_critical_region from devito.symbolics import pow_to_mul from devito.tools import Stamp, flatten, frozendict, timed_pass diff --git a/devito/passes/iet/engine.py b/devito/passes/iet/engine.py index 3655e4a413..617b0ff936 100644 --- a/devito/passes/iet/engine.py +++ b/devito/passes/iet/engine.py @@ -16,8 +16,7 @@ from devito.passes import needs_transfer from devito.symbolics import FieldFromComposite, FieldFromPointer, IndexedPointer, search from devito.tools import ( - DAG, as_hashable, as_tuple, filter_ordered, memoized_func, sorted_priority, - timed_pass + DAG, as_hashable, as_tuple, filter_ordered, memoized_func, sorted_priority, timed_pass ) from devito.types import ( Array, Auto, Bundle, ComponentAccess, CompositeObject, FunctionMap, IncrDimension, diff --git a/tests/test_ir.py b/tests/test_ir.py index 7611d39dfc..a805fb01cf 100644 --- a/tests/test_ir.py +++ b/tests/test_ir.py @@ -7,8 +7,8 @@ Constant, Dimension, Eq, Function, Grid, Inc, Operator, SubDimension, TimeFunction, switchconfig ) -from devito.ir.clusters import Cluster, ClusterGroup from devito.ir.cgen import ccode +from devito.ir.clusters import Cluster, ClusterGroup from 
devito.ir.equations import LoweredEq from devito.ir.equations.algorithms import dimension_sort from devito.ir.iet import FindNodes, Iteration diff --git a/tests/test_tools.py b/tests/test_tools.py index e1943fe802..a4dc6c7e68 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -7,9 +7,8 @@ from devito import Eq, Operator, switchenv from devito.tools import ( - DefaultFrozenDict, - CacheInstances, UnboundedMultiTuple, UnboundTuple, ctypes_to_cstr, filter_ordered, - memoized_meth, memoized_weak_meth, toposort, transitive_closure + CacheInstances, DefaultFrozenDict, UnboundedMultiTuple, UnboundTuple, ctypes_to_cstr, + filter_ordered, memoized_meth, memoized_weak_meth, toposort, transitive_closure ) from devito.types.basic import Symbol @@ -325,6 +324,7 @@ def __init__(self, left: int, right: int): assert obj0.value == (1, 2) assert obj0 is not obj1 + def test_switchenv(): # Save previous environment previous_environ = dict(os.environ) From 7c2dc7ade5914b5f1102b946a9f28e0275948694 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Tue, 16 Jun 2026 12:13:09 +0100 Subject: [PATCH 34/45] compiler: pep8 happiness --- devito/ir/equations/equation.py | 8 ++------ devito/ir/iet/visitors.py | 2 -- devito/ir/support/basic.py | 3 ++- devito/ir/support/space.py | 1 + devito/tools/memoization.py | 3 ++- 5 files changed, 7 insertions(+), 10 deletions(-) diff --git a/devito/ir/equations/equation.py b/devito/ir/equations/equation.py index f74fb65c84..73d0065a58 100644 --- a/devito/ir/equations/equation.py +++ b/devito/ir/equations/equation.py @@ -33,8 +33,8 @@ class IREq(sympy.Eq, Pickable): __rkwargs__ = ('ispace', 'conditionals', 'implicit_dims', 'operation') def _hashable_content(self): - return (*super()._hashable_content(), - *tuple(getattr(self, i) for i in self.__rkwargs__)) + return (super()._hashable_content() + + tuple(as_hashable(getattr(self, i)) for i in self.__rkwargs__)) @property def is_Scalar(self): @@ -74,10 +74,6 @@ def state(self): def operation(self): 
return self._operation - def _hashable_content(self): - return (super()._hashable_content() + - tuple(as_hashable(getattr(self, i)) for i in self.__rkwargs__)) - @property def is_Reduction(self): return self.operation in (OpInc, OpMin, OpMax, OpMinMax) diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 2b674426bf..69468f8256 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -7,7 +7,6 @@ import ctypes from collections import OrderedDict from collections.abc import Callable, Generator, Iterable, Iterator, Sequence -from contextlib import suppress from itertools import chain, groupby from typing import Any, Generic, TypeVar @@ -1203,7 +1202,6 @@ def visit(self, o, *args, **kwargs): # `start` and `stop` are part of this visitor's state. return GenericVisitor.visit(self, o, *args, **kwargs) - def visit_object(self, o: object, flag: bool = False) -> LazyVisit[Node, bool]: yield from () return flag # noqa: B901 diff --git a/devito/ir/support/basic.py b/devito/ir/support/basic.py index ba15b3dcf0..e87b8fe826 100644 --- a/devito/ir/support/basic.py +++ b/devito/ir/support/basic.py @@ -10,7 +10,7 @@ from devito.ir.support.utils import AccessMode, extrema from devito.ir.support.vector import LabeledVector, Vector from devito.symbolics import ( - compare_ops, q_affine, q_comp_acc, q_constant, retrieve_accesses, retrieve_indexed + compare_ops, q_affine, q_comp_acc, q_constant, retrieve_indexed ) from devito.tools import ( CacheInstances, Tag, as_mapper, as_tuple, cached_hash, filter_sorted, flatten, @@ -1496,6 +1496,7 @@ def is_regular(self): def vinf(entries): return Vector(*(entries + [S.Infinity])) + def disjoint_test(e0, e1, d, it): """ A rudimentary test to check if two accesses `e0` and `e1` along `d` within diff --git a/devito/ir/support/space.py b/devito/ir/support/space.py index 43b85516fb..4e95ebc4cb 100644 --- a/devito/ir/support/space.py +++ b/devito/ir/support/space.py @@ -799,6 +799,7 @@ def reset(self): return 
DataSpace(intervals, parts) + class IterationSpace(Space, CacheInstances): """ diff --git a/devito/tools/memoization.py b/devito/tools/memoization.py index cb36870e37..ce9b4a5ba8 100644 --- a/devito/tools/memoization.py +++ b/devito/tools/memoization.py @@ -324,7 +324,8 @@ def wrapper(cls, *args, **kwargs): if type(input_obj) is cls: names = getattr(cls, fields) if isinstance(fields, str) else fields for name in names: - if name in kwargs and kwargs[name] is not getattr(input_obj, name, None): + if name in kwargs and \ + kwargs[name] is not getattr(input_obj, name, None): break else: return input_obj From 78fb108363e0e9932acb34d9ff575de97769f256 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Tue, 16 Jun 2026 13:32:41 +0100 Subject: [PATCH 35/45] compiler: Comply with ruff --- devito/ir/iet/nodes.py | 4 ++-- devito/ir/support/basic.py | 2 +- devito/symbolics/manipulation.py | 9 ++++++--- devito/symbolics/search.py | 2 +- 4 files changed, 10 insertions(+), 7 deletions(-) diff --git a/devito/ir/iet/nodes.py b/devito/ir/iet/nodes.py index 626f2bef16..583106afdc 100644 --- a/devito/ir/iet/nodes.py +++ b/devito/ir/iet/nodes.py @@ -6,7 +6,7 @@ from collections import OrderedDict, namedtuple from collections.abc import Iterable from contextlib import suppress -from functools import cached_property, lru_cache +from functools import cache, cached_property import cgen as c from sympy import IndexedBase, sympify @@ -1646,7 +1646,7 @@ def functions(self): # *** Utils -@lru_cache(maxsize=None) +@cache def _constructor_args(cls): """ Return cached constructor argument names and default values for an IET type. diff --git a/devito/ir/support/basic.py b/devito/ir/support/basic.py index e87b8fe826..1c1031d3aa 100644 --- a/devito/ir/support/basic.py +++ b/devito/ir/support/basic.py @@ -1277,7 +1277,7 @@ def d_output_gen(self, writes=None): If ``writes`` is provided, restrict the analysis to those Functions. 
""" - for k, v in writes.items(): + for v in writes.values(): for w1 in v: for w2 in v: if any(not rule(w2, w1) for rule in self.rules): diff --git a/devito/symbolics/manipulation.py b/devito/symbolics/manipulation.py index 85ddfcaeab..41e3669516 100644 --- a/devito/symbolics/manipulation.py +++ b/devito/symbolics/manipulation.py @@ -331,7 +331,8 @@ def xreplace_indices(exprs, mapper, key=None): replaced = [uxreplace(i, mapper) for i in exprs0] if isinstance(exprs, Iterable): - if len(replaced) == len(exprs0) and all(i is j for i, j in zip(replaced, exprs0)): + if len(replaced) == len(exprs0) and \ + all(i is j for i, j in zip(replaced, exprs0, strict=True)): return exprs return replaced else: @@ -462,11 +463,13 @@ def reuse_if_untouched(expr, args, evaluate=False): args = tuple(args) if type(expr) is tuple: - if len(args) == len(expr) and all(a is b for a, b in zip(expr, args)): + if len(args) == len(expr) and \ + all(a is b for a, b in zip(expr, args, strict=True)): return expr return args elif type(expr) is list: - if len(args) == len(expr) and all(a is b for a, b in zip(expr, args)): + if len(args) == len(expr) and \ + all(a is b for a, b in zip(expr, args, strict=True)): return expr return list(args) elif all(a is b for a, b in zip(expr.args, args, strict=False)): diff --git a/devito/symbolics/search.py b/devito/symbolics/search.py index 55064cbc23..522210fbca 100644 --- a/devito/symbolics/search.py +++ b/devito/symbolics/search.py @@ -11,11 +11,11 @@ from devito.tools import as_tuple, memoized_func __all__ = [ + 'retrieve_accesses', 'retrieve_derivatives', 'retrieve_dimensions', 'retrieve_function_carriers', 'retrieve_functions', - 'retrieve_accesses', 'retrieve_indexed', 'retrieve_symbols', 'retrieve_terminals', From 431482cba868f9a34f5bb99160d99869cbce698b Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Tue, 16 Jun 2026 17:05:11 +0100 Subject: [PATCH 36/45] compiler: Use placeholder dtype to compute repr --- devito/types/basic.py | 2 ++ 1 file changed, 
2 insertions(+) diff --git a/devito/types/basic.py b/devito/types/basic.py index 22a87e848f..f80f23ba0c 100644 --- a/devito/types/basic.py +++ b/devito/types/basic.py @@ -498,6 +498,8 @@ def _C_name(self): @property def _C_ctype(self): + if self.dtype is None: + return CustomDtype('void') return dtype_to_ctype(self.dtype) def _subs(self, old, new, **hints): From 0f93ef93f46e1f7321907120fbd1ae769689c26c Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Tue, 16 Jun 2026 17:06:29 +0100 Subject: [PATCH 37/45] tests: Adjust visible_devices testing --- devito/parameters.py | 10 ++++++++-- tests/test_gpu_common.py | 38 ++++++++++++++++++++++++++++---------- tests/test_tools.py | 6 ++++++ 3 files changed, 42 insertions(+), 12 deletions(-) diff --git a/devito/parameters.py b/devito/parameters.py index 2aba1112ad..7fb957ce39 100644 --- a/devito/parameters.py +++ b/devito/parameters.py @@ -286,7 +286,9 @@ def __exit__(self, exc_type, exc_val, traceback): class switchenv(SwitchDecorator): """ - Temporarily set environment variables from a dictionary + Temporarily set environment variables from a dictionary. A value of None + unsets the corresponding environment variable for the duration of the + context. Note: This does not propagate any environment variables that change inside the context manager, so should be used cautiously. @@ -296,7 +298,11 @@ def __init__(self, params): self.params = params def __enter__(self): - os.environ.update(self.params) + for k, v in self.params.items(): + if v is None: + os.environ.pop(k, None) + else: + os.environ[k] = v def __exit__(self, exc_type, exc_val, traceback): os.environ.clear() diff --git a/tests/test_gpu_common.py b/tests/test_gpu_common.py index 4bde13a7fb..a11aea7a5f 100644 --- a/tests/test_gpu_common.py +++ b/tests/test_gpu_common.py @@ -79,6 +79,13 @@ class TestDeviceID: CUDA_VISIBLE_DEVICES are correctly handled. 
""" + visible_device_envs = ( + 'CUDA_VISIBLE_DEVICES', + 'NVIDIA_VISIBLE_DEVICES', + 'ROCR_VISIBLE_DEVICES', + 'HIP_VISIBLE_DEVICES' + ) + @pytest.mark.parametrize('env_variables', [{"CUDA_VISIBLE_DEVICES": "1"}, {"CUDA_VISIBLE_DEVICES": "1,2"}, {"CUDA_VISIBLE_DEVICES": "1,0"}, @@ -102,12 +109,21 @@ def test_visible_devices(self, env_variables): # All variants in parameterisation should yield deviceid 1 assert argmap1._physical_deviceid == 1 - # Check that physical deviceid is 0 when no environment variables set - op2 = Operator(eq) + def test_default_physical_deviceid(self): + """ + Test that the default physical device ID is 0 when no visible-device + environment variable is set. + """ + grid = Grid(shape=(10, 10)) + u = Function(name='u', grid=grid) + + eq = Eq(u, u+1) + + with switchenv({i: None for i in self.visible_device_envs}): + op2 = Operator(eq) - argmap2 = op2.arguments() - # Default physical deviceid expected to be 0 - assert argmap2._physical_deviceid == 0 + argmap2 = op2.arguments() + assert argmap2._physical_deviceid == 0 @pytest.mark.parallel(mode=2) @pytest.mark.parametrize('visible_devices', [ @@ -144,9 +160,10 @@ def test_visible_devices_mpi(self, visible_devices, mode): assert argmap1._physical_deviceid == expected # In default case, physical deviceid will equal rank - op2 = Operator(eq) - argmap2 = op2.arguments() - assert argmap2._physical_deviceid == rank + with switchenv({i: None for i in self.visible_device_envs}): + op2 = Operator(eq) + argmap2 = op2.arguments() + assert argmap2._physical_deviceid == rank def test_visible_devices_with_devito_deviceid(self): """Test interaction between CUDA_VISIBLE_DEVICES and DEVITO_DEVICEID""" @@ -186,8 +203,9 @@ def test_deviceid_per_rank(self, mode): op = Operator(Eq(u, u+1)) - argmap = op.arguments(deviceid=deviceid) - assert argmap._physical_deviceid == deviceid + with switchenv({i: None for i in self.visible_device_envs}): + argmap = op.arguments(deviceid=deviceid) + assert 
argmap._physical_deviceid == deviceid class TestCodeGeneration: diff --git a/tests/test_tools.py b/tests/test_tools.py index a4dc6c7e68..2996844ee5 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -336,5 +336,11 @@ def test_switchenv(): # Check a temporary variable is unset inside the context manager assert os.environ.get('TEST_VAR') is None + # Check an existing variable can be temporarily unset inside the context manager + with switchenv({'TEST_VAR_UNSET': 'foo'}): + with switchenv({'TEST_VAR_UNSET': None}): + assert os.environ.get('TEST_VAR_UNSET') is None + assert os.environ['TEST_VAR_UNSET'] == 'foo' + # Make sure the switchenv does not persist to verify switchenv works as intended assert dict(os.environ) == previous_environ From b5d4b7aea51bd46c2f747f871ec89130c9ea1f19 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Wed, 17 Jun 2026 10:04:54 +0100 Subject: [PATCH 38/45] compiler: Simplify reuse_if_untouched --- devito/symbolics/manipulation.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/devito/symbolics/manipulation.py b/devito/symbolics/manipulation.py index 41e3669516..44dabb7790 100644 --- a/devito/symbolics/manipulation.py +++ b/devito/symbolics/manipulation.py @@ -369,7 +369,7 @@ def flatten_args(args, op, ignore=None): def pow_to_mul(expr): - if type(expr) in (tuple, list): + if isinstance(expr, (tuple, list)): return reuse_if_untouched(expr, (pow_to_mul(i) for i in expr)) elif q_leaf(expr) or isinstance(expr, Basic): return expr @@ -462,16 +462,11 @@ def reuse_if_untouched(expr, args, evaluate=False): """ args = tuple(args) - if type(expr) is tuple: - if len(args) == len(expr) and \ - all(a is b for a, b in zip(expr, args, strict=True)): - return expr - return args - elif type(expr) is list: + if isinstance(expr, (tuple, list)): if len(args) == len(expr) and \ all(a is b for a, b in zip(expr, args, strict=True)): return expr - return list(args) + return type(expr)(args) elif all(a is b for a, b in 
zip(expr.args, args, strict=False)): return expr else: From e7afb05819bf58eb8412da9e6df19b5d35c4c6b8 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Wed, 17 Jun 2026 10:32:47 +0100 Subject: [PATCH 39/45] compiler: Simplify get_printer --- devito/ir/cgen/printer.py | 27 +++++++-------------------- devito/ir/iet/visitors.py | 1 + tests/test_dtypes.py | 3 +-- 3 files changed, 9 insertions(+), 22 deletions(-) diff --git a/devito/ir/cgen/printer.py b/devito/ir/cgen/printer.py index 38acc7b7a5..96ae8c56ae 100644 --- a/devito/ir/cgen/printer.py +++ b/devito/ir/cgen/printer.py @@ -18,14 +18,13 @@ from devito.arch.compiler import AOMPCompiler from devito.symbolics.inspection import has_integer_args, sympy_dtype from devito.symbolics.queries import q_leaf -from devito.tools import ctypes_to_cstr, ctypes_vector_mapper, dtype_to_ctype +from devito.tools import ( + ctypes_to_cstr, ctypes_vector_mapper, dtype_to_ctype, memoized_func +) from devito.types.basic import AbstractFunction __all__ = ['BasePrinter', 'ccode', 'get_printer'] -_preset_dtypes = (np.float32, np.float64, np.complex64, np.complex128) -_printer_registry = {} - class BasePrinter(CodePrinter): @@ -452,22 +451,9 @@ def _print_Fallback(self, expr): sympy.printing.str.StrPrinter._print_Add = BasePrinter._print_Add -def get_printer(printer, dtype=None): - try: - registry = _printer_registry[printer] - except KeyError: - default = printer() - registry = {None: default, default.dtype: default} - for i in _preset_dtypes: - registry.setdefault(i, printer(settings={'dtype': i})) - _printer_registry[printer] = registry - - try: - return registry[dtype] - except KeyError: - handle = printer(settings={'dtype': dtype}) - registry[dtype] = handle - return handle +@memoized_func +def get_printer(printer, dtype): + return printer(settings={'dtype': dtype}) def ccode(expr, printer=None, dtype=None): @@ -489,4 +475,5 @@ def ccode(expr, printer=None, dtype=None): if printer is None: from devito.passes.iet.languages.C import 
CPrinter printer = CPrinter + dtype = printer._default_settings['dtype'] if dtype is None else dtype return get_printer(printer, dtype).doprint(expr, None) diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index 69468f8256..e3135d097b 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -257,6 +257,7 @@ def __init__(self, *args, printer=None, **kwargs): self.printer = printer def ccode(self, expr, dtype=None): + dtype = self.printer._default_settings['dtype'] if dtype is None else dtype return get_printer(self.printer, dtype).doprint(expr, None) @property diff --git a/tests/test_dtypes.py b/tests/test_dtypes.py index 4206f70057..861ceb17d7 100644 --- a/tests/test_dtypes.py +++ b/tests/test_dtypes.py @@ -205,9 +205,8 @@ def test_math_functions(dtype: np.dtype[np.inexact], def test_printer_registry() -> None: - default = get_printer(CPrinter) + default = get_printer(CPrinter, np.float32) - assert get_printer(CPrinter) is default assert get_printer(CPrinter, np.float32) is default float64 = get_printer(CPrinter, np.float64) From 4747f8c77515f297941dab31704685c5dc087d76 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Wed, 17 Jun 2026 12:05:18 +0100 Subject: [PATCH 40/45] Drop Tag.cached_hash as not safe --- devito/tools/abc.py | 3 --- devito/tools/memoization.py | 2 ++ 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/devito/tools/abc.py b/devito/tools/abc.py index 814eabd7f2..2e489ac3c0 100644 --- a/devito/tools/abc.py +++ b/devito/tools/abc.py @@ -1,7 +1,5 @@ from hashlib import sha1 -from devito.tools.memoization import cached_hash - __all__ = ['Pickable', 'Reconstructable', 'Signer', 'Singleton', 'Stamp', 'Tag'] @@ -36,7 +34,6 @@ def __gt__(self, other): def __ge__(self, other): return self.val >= other.val - @cached_hash def __hash__(self): return hash((self.name, self.val)) diff --git a/devito/tools/memoization.py b/devito/tools/memoization.py index ce9b4a5ba8..7ffc461591 100644 --- 
a/devito/tools/memoization.py +++ b/devito/tools/memoization.py @@ -21,6 +21,8 @@ def cached_hash(func): Warning: avoid explicitly calling a superclass' cached ``__hash__`` on a subclass instance, as that would stash the superclass hash in ``_mhash``. + + Warning: avoid using it on pickled objects. """ @wraps(func) def wrapper(self): From 58db5e0a86cd0383290eb287eaee183d1b984e9a Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Wed, 17 Jun 2026 13:38:18 +0100 Subject: [PATCH 41/45] tests: Use np.random.RandomState where necessary --- tests/test_interpolation.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/test_interpolation.py b/tests/test_interpolation.py index eda7351bb4..75b125d83d 100644 --- a/tests/test_interpolation.py +++ b/tests/test_interpolation.py @@ -887,7 +887,8 @@ def test_interp_complex(self, dtype): sc.coordinates.data[:] = [.5, .5, .5] fc = Function(name="fc", grid=grid, npoint=2, dtype=dtype) - fc.data[:] = np.random.randn(*grid.shape) + 1j * np.random.randn(*grid.shape) + rng = np.random.RandomState(0) + fc.data[:] = rng.randn(*grid.shape) + 1j * rng.randn(*grid.shape) opC = Operator([sc.interpolate(expr=fc)], name="OpC") opC() @@ -903,7 +904,8 @@ def test_interp_complex_and_real(self, dtype): coordinates=sc.coordinates) fc = Function(name="fc", grid=grid, npoint=2, dtype=dtype) - fc.data[:] = np.random.randn(*grid.shape) + 1j * np.random.randn(*grid.shape) + rng = np.random.RandomState(0) + fc.data[:] = rng.randn(*grid.shape) + 1j * rng.randn(*grid.shape) exprs = sc.interpolate(expr=fc) + scre.interpolate(expr=Real(fc)) opC = Operator(exprs, name="OpC") opC() From 679734d7d85e862a0bc30ad065257eff94b072af Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Wed, 17 Jun 2026 15:47:19 +0100 Subject: [PATCH 42/45] compiler: Fixup lower_async_objs --- devito/passes/iet/asynchrony.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/devito/passes/iet/asynchrony.py b/devito/passes/iet/asynchrony.py 
index aa205e818c..74e91c96b7 100644 --- a/devito/passes/iet/asynchrony.py +++ b/devito/passes/iet/asynchrony.py @@ -38,7 +38,7 @@ def pthreadify(graph, **kwargs): AsyncMeta = namedtuple('AsyncMeta', 'sdata threads init shutdown') -@iet_pass +@iet_pass(updates_args=True) def lower_async_objs(iet, **kwargs): # Different actions depending on the Callable type iet, efuncs = _lower_async_objs(iet, **kwargs) From f96ddcf3aba19d4029af9ef7a0ec9e5bc4cb7fc9 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Thu, 18 Jun 2026 09:39:48 +0100 Subject: [PATCH 43/45] compiler: Refactor clean-up --- devito/ir/iet/nodes.py | 2 +- devito/ir/iet/visitors.py | 12 ++++-------- devito/ir/support/basic.py | 4 +--- devito/passes/clusters/fusion.py | 3 +++ 4 files changed, 9 insertions(+), 12 deletions(-) diff --git a/devito/ir/iet/nodes.py b/devito/ir/iet/nodes.py index 583106afdc..bc4b5dc48c 100644 --- a/devito/ir/iet/nodes.py +++ b/devito/ir/iet/nodes.py @@ -1683,7 +1683,7 @@ def _same_as_before(old, new): if type(old) is not type(new): return False - if isinstance(old, dict): + if isinstance(old, dict) and isinstance(new, dict): return old.keys() == new.keys() and all( _same_as_before(v, new[k]) for k, v in old.items() ) diff --git a/devito/ir/iet/visitors.py b/devito/ir/iet/visitors.py index e3135d097b..4208c4b237 100644 --- a/devito/ir/iet/visitors.py +++ b/devito/ir/iet/visitors.py @@ -1560,14 +1560,10 @@ def visit_KernelLaunch(self, o): def reuse_if_unchanged(o, *children, **kwargs): - if children: - same_children = all( - _same_as_before(i, j) for i, j in zip(o.children, children, strict=True) - ) - if not same_children: - return o._rebuild(*children, **kwargs) - - if kwargs and not all(o._same_arg(k, v) for k, v in kwargs.items()): + if children and not _same_as_before(o.children, children): + return o._rebuild(*children, **kwargs) + + if kwargs: return o._rebuild(*children, **kwargs) return o diff --git a/devito/ir/support/basic.py b/devito/ir/support/basic.py index 
1c1031d3aa..9283f2c2df 100644 --- a/devito/ir/support/basic.py +++ b/devito/ir/support/basic.py @@ -218,9 +218,7 @@ class TimedAccess(IterationInstance, AccessMode, CacheInstances): """ @classmethod - def _preprocess_args(cls, access, mode, timestamp, ispace=None): - if ispace is None: - ispace = null_ispace + def _preprocess_args(cls, access, mode, timestamp, ispace=null_ispace): return (access, mode, timestamp, ispace), {} def __new__(cls, access, mode, timestamp, ispace=None): diff --git a/devito/passes/clusters/fusion.py b/devito/passes/clusters/fusion.py index c8b6ac2cb1..837b76df13 100644 --- a/devito/passes/clusters/fusion.py +++ b/devito/passes/clusters/fusion.py @@ -24,6 +24,9 @@ @memoized_func(scope='build') def _fusion_hazards(scope0, scope1, prefix): + """ + Classify the dependence hazard that would arise from fusing two scopes. + """ scope = Scope.from_scopes(scope0, scope1) if scope is None: return NO_HAZARD From be11274d17c0231c5db200f90bdd0e3bd2db0f80 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Thu, 18 Jun 2026 10:35:26 +0100 Subject: [PATCH 44/45] compiler: Remove dangerous updates_args feature --- devito/passes/iet/asynchrony.py | 2 +- devito/passes/iet/definitions.py | 2 +- devito/passes/iet/engine.py | 30 +++++++++--------------------- devito/passes/iet/linearization.py | 2 +- tests/test_iet.py | 23 +---------------------- 5 files changed, 13 insertions(+), 46 deletions(-) diff --git a/devito/passes/iet/asynchrony.py b/devito/passes/iet/asynchrony.py index 74e91c96b7..aa205e818c 100644 --- a/devito/passes/iet/asynchrony.py +++ b/devito/passes/iet/asynchrony.py @@ -38,7 +38,7 @@ def pthreadify(graph, **kwargs): AsyncMeta = namedtuple('AsyncMeta', 'sdata threads init shutdown') -@iet_pass(updates_args=True) +@iet_pass def lower_async_objs(iet, **kwargs): # Different actions depending on the Callable type iet, efuncs = _lower_async_objs(iet, **kwargs) diff --git a/devito/passes/iet/definitions.py b/devito/passes/iet/definitions.py index 
9fc20bfcee..2f1cce8f10 100644 --- a/devito/passes/iet/definitions.py +++ b/devito/passes/iet/definitions.py @@ -444,7 +444,7 @@ def _inject_definitions(self, iet, storage): return processed, flatten(efuncs) - @iet_pass(updates_args=True) + @iet_pass def place_definitions(self, iet, globs=None, **kwargs): """ Create a new IET where all symbols have been declared, allocated, and diff --git a/devito/passes/iet/engine.py b/devito/passes/iet/engine.py index 617b0ff936..6d1a9527ee 100644 --- a/devito/passes/iet/engine.py +++ b/devito/passes/iet/engine.py @@ -127,21 +127,14 @@ def sync_mapper(self): return found - def apply(self, func, *, updates_args=False, **kwargs): + def apply(self, func, **kwargs): """ Apply ``func`` to all nodes in the Graph. - Parameters - ---------- - updates_args : bool, optional - If True, reconcile Callable parameters and Call arguments before - the graph walk and after each changed node. This is only needed by - passes whose transformation logic depends on already-updated - signatures while the pass is still running. Otherwise, argument - reconciliation is intentionally deferred to ``finalize_args``. + Callable parameters and Call arguments are reconciled before the graph + walk, after each changed node, and after the pass has completed. 
""" - if updates_args: - _update_args(self) + _update_args(self) dag = create_call_graph(self.root.name, as_hashable(self.efuncs)) @@ -172,8 +165,7 @@ def apply(self, func, *, updates_args=False, **kwargs): efuncs[i] = efunc efuncs.update(dict([(i.name, i) for i in new_efuncs])) - if updates_args: - efuncs = _update_args_efunc(efunc, efuncs, dag) + efuncs = _update_args_efunc(efunc, efuncs, dag) # Minimize code size if len(efuncs) > len(self.efuncs): @@ -182,6 +174,7 @@ def apply(self, func, *, updates_args=False, **kwargs): efuncs = reuse_efuncs(self.root, efuncs, self.sregistry) self.efuncs = efuncs + _update_args(self) # Uniqueness self.includes = filter_ordered(self.includes) @@ -240,24 +233,19 @@ def _update_args(graph): graph.efuncs = efuncs -def iet_pass(func=None, *, updates_args=False): +def iet_pass(func=None): """ Decorate an IET pass. - - ``updates_args=True`` is an opt-in for passes that must observe up-to-date - Callable/Call signatures before and during their own graph walk. Most - passes should leave it False and rely on ``finalize_args`` at the end of - IET lowering. """ if func is None: - return partial(iet_pass, updates_args=updates_args) + return iet_pass if isinstance(func, tuple): assert len(func) == 2 and func[0] is iet_visit call = lambda graph: graph.visit func = func[1] else: - call = lambda graph: partial(graph.apply, updates_args=updates_args) + call = lambda graph: graph.apply @wraps(func) def wrapper(*args, **kwargs): diff --git a/devito/passes/iet/linearization.py b/devito/passes/iet/linearization.py index 78ec84987c..aca2485444 100644 --- a/devito/passes/iet/linearization.py +++ b/devito/passes/iet/linearization.py @@ -46,7 +46,7 @@ def linearize(graph, **kwargs): linearization(graph, key=key, tracker=tracker, **kwargs) -@iet_pass(updates_args=True) +@iet_pass def linearization(iet, key=None, tracker=None, **kwargs): """ Carry out the actual work of `linearize`. 
diff --git a/tests/test_iet.py b/tests/test_iet.py index 0c3f944a32..e8e8f8444f 100644 --- a/tests/test_iet.py +++ b/tests/test_iet.py @@ -14,8 +14,7 @@ ElementalFunction, FindSymbols, Iteration, KernelLaunch, Lambda, List, Switch, Transformer, filter_iterations, make_efunc, retrieve_iteration_tree ) -from devito.passes.iet import engine as iet_engine -from devito.passes.iet.engine import Graph, iet_pass +from devito.passes.iet.engine import Graph from devito.passes.iet.languages.C import CDataManager from devito.symbolics import ( FLOAT, Byref, Class, FieldFromComposite, InlineIf, ListInitializer, Macro, SizeOf, @@ -540,26 +539,6 @@ def test_complex_array(): "float _Complex **restrict a_vec __attribute__ ((aligned (64)));" -def test_iet_pass_does_not_update_args(monkeypatch): - x = Symbol(name='x') - y = Symbol(name='y') - - foo = Callable('foo', DummyExpr(x, y), 'void', parameters=(x, y)) - graph = Graph(foo) - - @iet_pass - def inject_expr(iet): - body = iet.body._rebuild(body=iet.body.body + (DummyExpr(x, x),)) - return iet._rebuild(body=body), {} - - monkeypatch.setattr(iet_engine, '_update_args', - lambda *args, **kwargs: pytest.fail("_update_args called")) - - inject_expr(graph) - - assert graph.root.parameters is foo.parameters - - def test_special_array_definition(): class MyArray(Array): From 1f261ebf3f7b79dc1ea03584327b171daace0a0d Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Fri, 19 Jun 2026 13:29:53 +0100 Subject: [PATCH 45/45] compiler: Refactor fusion.py --- devito/passes/clusters/fusion.py | 57 ++++++++++++++++---------------- 1 file changed, 29 insertions(+), 28 deletions(-) diff --git a/devito/passes/clusters/fusion.py b/devito/passes/clusters/fusion.py index 837b76df13..994b4c9b65 100644 --- a/devito/passes/clusters/fusion.py +++ b/devito/passes/clusters/fusion.py @@ -50,6 +50,34 @@ def _fusion_hazards(scope0, scope1, prefix): return NO_HAZARD +class Key(tuple): + + """ + A "fusion Key" for a Cluster (ClusterGroup) is a hashable tuple such 
that + two Clusters (ClusterGroups) are topo-fusible if and only if their Key is + identical. + + A Key contains elements that can logically be split into two groups -- the + `strict` and the `weak` components of the Key. Two Clusters (ClusterGroups) + having same `strict` but different `weak` parts are, by definition, not + fusible; however, since at least their `strict` parts match, they can at + least be topologically reordered. + """ + + def __new__(cls, itintervals, guards, syncs, weak): + strict = [itintervals, guards, syncs] + obj = super().__new__(cls, strict + weak) + + obj.itintervals = itintervals + obj.guards = guards + obj.syncs = syncs + + obj.strict = tuple(strict) + obj.weak = tuple(weak) + + return obj + + class Fusion(Queue): """ @@ -106,33 +134,6 @@ def callback(self, cgroups, prefix): else: return [ClusterGroup(processed, prefix)] - class Key(tuple): - - """ - A fusion Key for a Cluster (ClusterGroup) is a hashable tuple such that - two Clusters (ClusterGroups) are topo-fusible if and only if their Key is - identical. - - A Key contains elements that can logically be split into two groups -- the - `strict` and the `weak` components of the Key. Two Clusters (ClusterGroups) - having same `strict` but different `weak` parts are, by definition, not - fusible; however, since at least their `strict` parts match, they can at - least be topologically reordered. - """ - - def __new__(cls, itintervals, guards, syncs, weak): - strict = [itintervals, guards, syncs] - obj = super().__new__(cls, strict + weak) - - obj.itintervals = itintervals - obj.guards = guards - obj.syncs = syncs - - obj.strict = tuple(strict) - obj.weak = tuple(weak) - - return obj - @memoized_meth def _key(self, c): itintervals = frozenset(c.ispace.itintervals) @@ -184,7 +185,7 @@ def _key(self, c): # Promote adjacency of Clusters with same guard weak.append(c.guards) - key = self.Key(itintervals, guards, syncs, weak) + key = Key(itintervals, guards, syncs, weak) return key