Skip to content

Commit cb9731e

Browse files
committed
compiler: add rudimentary support for multi-cond buffering
1 parent e40fe22 commit cb9731e

4 files changed

Lines changed: 49 additions & 5 deletions

File tree

devito/ir/clusters/algorithms.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,7 @@ def guard(clusters):
254254

255255
# Chain together all `cds` conditions from all expressions in `c`
256256
guards = {}
257+
mode = sympy.Or
257258
for cd in cds:
258259
# `BOTTOM` parent implies a guard that lives outside of
259260
# any iteration space, which corresponds to the placeholder None
@@ -270,6 +271,7 @@ def guard(clusters):
270271

271272
# Pull `cd` from any expr
272273
condition = guards.setdefault(k, [])
274+
mode = mode and cd.relation
273275
for e in exprs:
274276
try:
275277
condition.append(e.conditionals[cd])
@@ -284,7 +286,8 @@ def guard(clusters):
284286
conditionals.pop(cd, None)
285287
exprs[i] = e.func(*e.args, conditionals=conditionals)
286288

287-
guards = {d: sympy.And(*v, evaluate=False) for d, v in guards.items()}
289+
# Combination mode is And by default and Or if all conditions are
290+
guards = {d: mode(*v, evaluate=False) for d, v in guards.items()}
288291

289292
# Construct a guarded Cluster
290293
processed.append(c.rebuild(exprs=exprs, guards=guards))

devito/ir/equations/equation.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -223,7 +223,7 @@ def __new__(cls, *args, **kwargs):
223223
else:
224224
cond = diff2sympy(lower_exprs(d.condition))
225225
if d._factor is not None:
226-
cond = sympy.And(cond, GuardFactor(d))
226+
cond = d.relation(cond, GuardFactor(d))
227227
conditionals[d] = cond
228228
# Replace dimension with index
229229
index = d.index

devito/types/dimension.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -860,6 +860,8 @@ class ConditionalDimension(DerivedDimension):
860860
If True, use `self`, rather than the parent Dimension, to
861861
index into arrays. A typical use case is when arrays are accessed
862862
indirectly via the ``condition`` expression.
863+
relation: Or/And, default=And
864+
How this ConditionalDimension will be combined with other ones.
863865
864866
Examples
865867
--------
@@ -913,10 +915,10 @@ class ConditionalDimension(DerivedDimension):
913915
is_Conditional = True
914916

915917
__rkwargs__ = DerivedDimension.__rkwargs__ + \
916-
('factor', 'condition', 'indirect')
918+
('factor', 'condition', 'indirect', 'relation')
917919

918920
def __init_finalize__(self, name, parent=None, factor=None, condition=None,
919-
indirect=False, **kwargs):
921+
indirect=False, relation=sympy.And, **kwargs):
920922
# `parent=None` degenerates to a ConditionalDimension outside of
921923
# any iteration space
922924
if parent is None:
@@ -937,6 +939,7 @@ def __init_finalize__(self, name, parent=None, factor=None, condition=None,
937939

938940
self._condition = condition
939941
self._indirect = indirect
942+
self._relation = relation
940943

941944
@property
942945
def uses_symbolic_factor(self):
@@ -978,6 +981,10 @@ def condition(self):
978981
def indirect(self):
979982
return self._indirect
980983

984+
@property
985+
def relation(self):
986+
return self._relation
987+
981988
@cached_property
982989
def free_symbols(self):
983990
retval = set(super().free_symbols)

tests/test_buffering.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import numpy as np
22
import pytest
3+
from sympy import Or
34

45
from conftest import skipif
56
from devito import (
6-
ConditionalDimension, Constant, Eq, Grid, Operator, SubDimension, SubDomain,
7+
CondEq, ConditionalDimension, Constant, Eq, Grid, Operator, SubDimension, SubDomain,
78
TimeFunction, configuration, switchconfig
89
)
910
from devito.arch.archinfo import AppleArm
@@ -751,3 +752,36 @@ def test_buffer_reuse():
751752

752753
assert all(np.all(usave.data[i-1] == i) for i in range(1, nt + 1))
753754
assert all(np.all(vsave.data[i-1] == i + 1) for i in range(1, nt + 1))
755+
756+
757+
def test_multi_cond():
758+
grid = Grid((3, 3))
759+
nt = 5
760+
761+
x, y = grid.dimensions
762+
763+
factor = 2
764+
ntmod = (nt - 1) * factor + 1
765+
766+
ct1 = ConditionalDimension(name="ct1", parent=grid.time_dim,
767+
factor=factor, relation=Or)
768+
ctend = ConditionalDimension(name="ctend", parent=grid.time_dim,
769+
condition=CondEq(grid.time_dim, ntmod - 2),
770+
relation=Or)
771+
772+
f = TimeFunction(grid=grid, name='f', time_order=0,
773+
space_order=0, save=nt, time_dim=ct1)
774+
T = TimeFunction(grid=grid, name='T', time_order=0, space_order=0)
775+
776+
eqs = [Eq(T, grid.time_dim)]
777+
# this to save times from 0 to nt - 2
778+
eqs.append(Eq(f, T))
779+
# this to save the last time sample nt - 1
780+
eqs.append(Eq(f.forward, T+1, implicit_dims=ctend))
781+
782+
# run operator with buffering
783+
op = Operator(eqs, opt=('streaming', 'buffering'))
784+
op.apply(time_m=0, time_M=ntmod-2)
785+
786+
for i in range(nt):
787+
assert np.allclose(f.data[i], i*2)

0 commit comments

Comments
 (0)