From 6606bc814bc1349d309f960c043115f2db23af19 Mon Sep 17 00:00:00 2001 From: Fabio Luporini Date: Sat, 31 Jan 2026 13:35:43 +0000 Subject: [PATCH] compiler: Turn aliases' choose() into an instance method --- devito/passes/clusters/aliases.py | 92 ++++++++++++++++--------------- 1 file changed, 48 insertions(+), 44 deletions(-) diff --git a/devito/passes/clusters/aliases.py b/devito/passes/clusters/aliases.py index f0c25a74be..a126f96684 100644 --- a/devito/passes/clusters/aliases.py +++ b/devito/passes/clusters/aliases.py @@ -111,6 +111,7 @@ class CireTransformer: def __init__(self, sregistry, options, platform): self.sregistry = sregistry self.platform = platform + self.opt_minstorage = options['min-storage'] self.opt_rotate = options['cire-rotate'] self.opt_ftemps = options['cire-ftemps'] @@ -125,7 +126,7 @@ def _aliases_from_clusters(self, cgroup, exclude, meta): for mapper in self._generate(cgroup, exclude): # Clusters -> AliasList found = collect(mapper.extracted, meta.ispace, self.opt_minstorage) - exprs, aliases = choose(found, cgroup, mapper, self.opt_mingain) + exprs, aliases = self._choose(found, cgroup, mapper) # AliasList -> Schedule schedule = lower_aliases(aliases, meta, self.opt_maxpar) @@ -189,6 +190,52 @@ def _lookup_key(self, c): """ raise NotImplementedError + def _choose(self, aliases, cgroup, mapper): + """ + Analyze the detected aliases and, after applying a cost model to rule + out the aliases with a bad memory/flops trade-off, inject them into the + original expressions. + """ + exprs = cgroup.exprs + + aliases = AliasList(aliases) + if not aliases: + return exprs, aliases + + # `score < m` => discarded + # `score > M` => optimized + # `m <= score <= M` => maybe optimized, depends on heuristics + m = self.opt_mingain + M = self.opt_mingain*3 + + # Filter off the aliases with low score + key = lambda a: a.score >= m + aliases.filter(key) + + # Project the candidate aliases into `exprs` to derive the final + # working set + mapper = {k: v for k, v in mapper.items() + if v.free_symbols & set(aliases.aliaseds)} + templated = [uxreplace(e, mapper) for e in exprs] + owset = wset(templated) + + # Filter off the aliases with a weak flop-reduction / working-set tradeoff + key = lambda a: \ + a.score > M or \ + m <= a.score <= M and (max(len(wset(a.pivot)), 1) > + len(wset(a.pivot) & owset)) + aliases.filter(key) + + if not aliases: + return exprs, aliases + + # Substitute the chosen aliasing sub-expressions + mapper = {k: v for k, v in mapper.items() + if v.free_symbols & set(aliases.aliaseds)} + exprs = [uxreplace(e, mapper) for e in exprs] + + return exprs, aliases + def _select(self, variants): """ Select the best variant out of a set of `variants`, weighing flops and @@ -611,49 +658,6 @@ def collect(extracted, ispace, minstorage): return aliases -def choose(aliases, cgroup, mapper, mingain): - """ - Analyze the detected aliases and, after applying a cost model to rule out - the aliases with a bad memory/flops trade-off, inject them into the original - expressions. - """ - exprs = cgroup.exprs - - aliases = AliasList(aliases) - if not aliases: - return exprs, aliases - - # `score < m` => discarded - # `score > M` => optimized - # `m <= score <= M` => maybe discarded, maybe optimized; depends on heuristics - m = mingain - M = mingain*3 - - # Filter off the aliases with low score - key = lambda a: a.score >= m - aliases.filter(key) - - # Project the candidate aliases into `exprs` to derive the final working set - mapper = {k: v for k, v in mapper.items() if v.free_symbols & set(aliases.aliaseds)} - templated = [uxreplace(e, mapper) for e in exprs] - owset = wset(templated) - - # Filter off the aliases with a weak flop-reduction / working-set tradeoff - key = lambda a: \ - a.score > M or \ - m <= a.score <= M and max(len(wset(a.pivot)), 1) > len(wset(a.pivot) & owset) - aliases.filter(key) - - if not aliases: - return exprs, aliases - - # Substitute the chosen aliasing sub-expressions - mapper = {k: v for k, v in mapper.items() if v.free_symbols & set(aliases.aliaseds)} - exprs = [uxreplace(e, mapper) for e in exprs] - - return exprs, aliases - - def lower_aliases(aliases, meta, maxpar): """ Create a Schedule from an AliasList.