Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 48 additions & 44 deletions devito/passes/clusters/aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ class CireTransformer:
def __init__(self, sregistry, options, platform):
self.sregistry = sregistry
self.platform = platform

self.opt_minstorage = options['min-storage']
self.opt_rotate = options['cire-rotate']
self.opt_ftemps = options['cire-ftemps']
Expand All @@ -125,7 +126,7 @@ def _aliases_from_clusters(self, cgroup, exclude, meta):
for mapper in self._generate(cgroup, exclude):
# Clusters -> AliasList
found = collect(mapper.extracted, meta.ispace, self.opt_minstorage)
exprs, aliases = choose(found, cgroup, mapper, self.opt_mingain)
exprs, aliases = self._choose(found, cgroup, mapper)

# AliasList -> Schedule
schedule = lower_aliases(aliases, meta, self.opt_maxpar)
Expand Down Expand Up @@ -189,6 +190,52 @@ def _lookup_key(self, c):
"""
raise NotImplementedError

def _choose(self, aliases, cgroup, mapper):
    """
    Analyze the detected aliases and, after applying a cost model to rule
    out the aliases with a bad memory/flops trade-off, inject them into the
    original expressions.

    Parameters
    ----------
    aliases : iterable
        The candidate aliases produced by `collect`.
    cgroup : ClusterGroup
        The ClusterGroup whose expressions are being optimized.
    mapper : dict
        Mapping from the extracted temporaries to the aliasing
        sub-expressions they stand for.

    Returns
    -------
    (exprs, aliases) : (list, AliasList)
        The (possibly rewritten) expressions and the surviving aliases.
    """
    exprs = cgroup.exprs

    aliases = AliasList(aliases)
    if not aliases:
        return exprs, aliases

    # Cost-model thresholds:
    #   `score < m`      => discarded
    #   `score > M`      => optimized
    #   `m <= score <= M` => maybe optimized, depends on heuristics
    m = self.opt_mingain
    # NOTE(review): the factor 3 is a heuristic — an alias scoring at least
    # three times the minimum gain is considered unconditionally worth
    # optimizing. TODO: confirm/justify, possibly promote to an option
    M = m*3

    # Filter off the aliases with low score
    aliases.filter(lambda a: a.score >= m)

    # Project the candidate aliases into `exprs` to derive the final
    # working set
    # Hoisted out of the comprehension: invariant across all items
    aliaseds = set(aliases.aliaseds)
    mapper = {k: v for k, v in mapper.items()
              if v.free_symbols & aliaseds}
    templated = [uxreplace(e, mapper) for e in exprs]
    owset = wset(templated)

    def retain(a):
        """
        True iff alias `a` has a worthwhile flop-reduction / working-set
        tradeoff: either its score is above the high threshold, or it is
        mid-score and introduces at least one symbol not already in the
        output working set (i.e., it genuinely shrinks recomputation).
        """
        if a.score > M:
            return True
        pivot_wset = wset(a.pivot)
        return (m <= a.score <= M and
                max(len(pivot_wset), 1) > len(pivot_wset & owset))

    # Filter off the aliases with a weak flop-reduction / working-set tradeoff
    aliases.filter(retain)

    if not aliases:
        return exprs, aliases

    # Substitute the chosen aliasing sub-expressions
    # Recompute: `aliases` may have shrunk since the previous projection
    aliaseds = set(aliases.aliaseds)
    mapper = {k: v for k, v in mapper.items()
              if v.free_symbols & aliaseds}
    exprs = [uxreplace(e, mapper) for e in exprs]

    return exprs, aliases

def _select(self, variants):
"""
Select the best variant out of a set of `variants`, weighing flops and
Expand Down Expand Up @@ -611,49 +658,6 @@ def collect(extracted, ispace, minstorage):
return aliases


def choose(aliases, cgroup, mapper, mingain):
    """
    Analyze the detected aliases and, after applying a cost model to rule out
    the aliases with a bad memory/flops trade-off, inject them into the original
    expressions.

    Parameters
    ----------
    aliases : iterable
        The candidate aliases produced by `collect`.
    cgroup : ClusterGroup
        The ClusterGroup whose expressions are being optimized.
    mapper : dict
        Mapping from the extracted temporaries to the aliasing
        sub-expressions they stand for.
    mingain : int
        The minimum score an alias must reach to be considered for
        optimization.

    Returns
    -------
    (exprs, aliases) : (list, AliasList)
        The (possibly rewritten) expressions and the surviving aliases.
    """
    exprs = cgroup.exprs

    aliases = AliasList(aliases)
    if not aliases:
        return exprs, aliases

    # Cost-model thresholds:
    #   `score < m`       => discarded
    #   `score > M`       => optimized
    #   `m <= score <= M` => maybe discarded, maybe optimized; depends on heuristics
    m = mingain
    # NOTE(review): the factor 3 is a heuristic — an alias scoring at least
    # three times `mingain` is considered unconditionally worth optimizing
    M = mingain*3

    # Filter off the aliases with low score
    aliases.filter(lambda a: a.score >= m)

    # Project the candidate aliases into `exprs` to derive the final working set
    # Hoisted out of the comprehension: invariant across all items
    aliaseds = set(aliases.aliaseds)
    mapper = {k: v for k, v in mapper.items() if v.free_symbols & aliaseds}
    templated = [uxreplace(e, mapper) for e in exprs]
    owset = wset(templated)

    def retain(a):
        """
        True iff alias `a` has a worthwhile flop-reduction / working-set
        tradeoff: either its score is above the high threshold, or it is
        mid-score and introduces at least one symbol not already in the
        output working set.
        """
        if a.score > M:
            return True
        pivot_wset = wset(a.pivot)
        return (m <= a.score <= M and
                max(len(pivot_wset), 1) > len(pivot_wset & owset))

    # Filter off the aliases with a weak flop-reduction / working-set tradeoff
    aliases.filter(retain)

    if not aliases:
        return exprs, aliases

    # Substitute the chosen aliasing sub-expressions
    # Recompute: `aliases` may have shrunk since the previous projection
    aliaseds = set(aliases.aliaseds)
    mapper = {k: v for k, v in mapper.items() if v.free_symbols & aliaseds}
    exprs = [uxreplace(e, mapper) for e in exprs]

    return exprs, aliases


def lower_aliases(aliases, meta, maxpar):
"""
Create a Schedule from an AliasList.
Expand Down
Loading