Skip to content

Commit 4afac04

Browse files
authored
Merge pull request #2259 from firedrakeproject/sv/local-preconditioners
Introducing a few local preconditioners
2 parents 5137875 + e9990b1 commit 4afac04

10 files changed

Lines changed: 958 additions & 101 deletions

File tree

firedrake/assemble.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -693,8 +693,6 @@ def _make_parloops(expr, tensor, bcs, diagonal, fc_params, assembly_rank):
693693
domains = expr.ufl_domains()
694694

695695
if isinstance(expr, slate.TensorBase):
696-
if diagonal:
697-
raise NotImplementedError("Diagonal + slate not supported")
698696
kernels = slac.compile_expression(expr, compiler_parameters=form_compiler_parameters)
699697
else:
700698
kernels = tsfc_interface.compile_form(expr, "form", parameters=form_compiler_parameters, diagonal=diagonal)

firedrake/slate/slac/compiler.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def generate_loopy_kernel(slate_expr, compiler_parameters=None):
167167

168168
# Create a loopy builder for the Slate expression,
169169
# e.g. contains the loopy kernels coming from TSFC
170-
gem_expr, var2terminal = slate_to_gem(slate_expr)
170+
gem_expr, var2terminal = slate_to_gem(slate_expr, compiler_parameters["slate_compiler"])
171171

172172
scalar_type = compiler_parameters["form_compiler"]["scalar_type"]
173173
slate_loopy, output_arg = gem_to_loopy(gem_expr, var2terminal, scalar_type)

firedrake/slate/slac/optimise.py

Lines changed: 76 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,12 @@ def optimise(expression, parameters):
2727
2828
Returns: An optimised Slate expression
2929
"""
30-
# 1) Block optimisation
30+
# 0) Block optimisation
3131
expression = push_block(expression)
3232

33+
# 1) DiagonalTensor optimisation
34+
expression = push_diag(expression)
35+
3336
# 2) Multiplication optimisation
3437
if expression.rank < 2:
3538
expression = push_mul(expression, parameters)
@@ -70,6 +73,8 @@ def _push_block_transpose(expr, self, indices):
7073

7174
@_push_block.register(Add)
7275
@_push_block.register(Negative)
76+
@_push_block.register(DiagonalTensor)
77+
@_push_block.register(Reciprocal)
7378
def _push_block_distributive(expr, self, indices):
7479
"""Distributes Blocks for these nodes"""
7580
return type(expr)(*map(self, expr.children, repeat(indices))) if indices else expr
@@ -111,6 +116,66 @@ def _push_block_block(expr, self, indices):
111116
return block
112117

113118

119+
def push_diag(expression):
120+
"""Executes a Slate compiler optimisation pass.
121+
The optimisation is achieved by pushing DiagonalTensor from the outside to the inside of an expression.
122+
123+
:arg expression: A (potentially unoptimised) Slate expression.
124+
125+
Returns: An optimised Slate expression, where DiagonalTensors are sitting
126+
on terminal tensors whereever possible.
127+
"""
128+
mapper = MemoizerArg(_push_diag)
129+
return mapper(expression, False)
130+
131+
132+
@singledispatch
133+
def _push_diag(expr, self, diag):
134+
raise AssertionError("Cannot handle terminal type: %s" % type(expr))
135+
136+
137+
@_push_diag.register(Transpose)
138+
@_push_diag.register(Add)
139+
@_push_diag.register(Negative)
140+
def _push_diag_distributive(expr, self, diag):
141+
"""Distributes the DiagonalTensors into these nodes"""
142+
return type(expr)(*map(self, expr.children, repeat(diag)))
143+
144+
145+
@_push_diag.register(Factorization)
146+
@_push_diag.register(Inverse)
147+
@_push_diag.register(Solve)
148+
@_push_diag.register(Mul)
149+
@_push_diag.register(Tensor)
150+
def _push_diag_stop(expr, self, diag):
151+
"""Diagonal Tensors cannot be pushed further into this set of nodes."""
152+
expr = type(expr)(*map(self, expr.children, repeat(False))) if not expr.terminal else expr
153+
return DiagonalTensor(expr) if diag else expr
154+
155+
156+
@_push_diag.register(Block)
157+
def _push_diag_block(expr, self, diag):
158+
"""Diagonal Tensors cannot be pushed further into this set of nodes."""
159+
expr = type(expr)(*map(self, expr.children, repeat(False)), expr._indices) if not expr.terminal else expr
160+
return DiagonalTensor(expr) if diag else expr
161+
162+
163+
@_push_diag.register(AssembledVector)
164+
@_push_diag.register(Reciprocal)
165+
def _push_diag_vectors(expr, self, diag):
166+
"""DiagonalTensors should not be pushed onto rank-1 tensors."""
167+
if diag:
168+
raise AssertionError("It is not legal to define DiagonalTensors on rank-1 tensors.")
169+
else:
170+
return expr
171+
172+
173+
@_push_diag.register(DiagonalTensor)
174+
def _push_diag_diag(expr, self, diag):
175+
"""DiagonalTensors are either pushed down or ignored when wrapped into another DiagonalTensor."""
176+
return self(*expr.children, not diag)
177+
178+
114179
def push_mul(tensor, options):
115180
"""Executes a Slate compiler optimisation pass.
116181
The optimisation is achieved by pushing coefficients from
@@ -179,6 +244,8 @@ def _drop_double_transpose_transpose(expr, self):
179244
@_drop_double_transpose.register(Mul)
180245
@_drop_double_transpose.register(Solve)
181246
@_drop_double_transpose.register(Inverse)
247+
@_drop_double_transpose.register(DiagonalTensor)
248+
@_drop_double_transpose.register(Reciprocal)
182249
def _drop_double_transpose_distributive(expr, self):
183250
"""Distribute into the children of the expression. """
184251
return type(expr)(*map(self, expr.children))
@@ -202,6 +269,8 @@ def _push_mul_tensor(expr, self, state):
202269

203270

204271
@_push_mul.register(AssembledVector)
272+
@_push_mul.register(DiagonalTensor)
273+
@_push_mul.register(Reciprocal)
205274
def _push_mul_vector(expr, self, state):
206275
"""Do not push into AssembledVectors."""
207276
return expr
@@ -220,8 +289,12 @@ def _push_mul_inverse(expr, self, state):
220289
with a coefficient into a Solve via A.inv*b = A.solve(b)
221290
or b*A^{-1}= (A.T.inv*b.T).T = A.T.solve(b.T).T ."""
222291
child, = expr.children
223-
return (Solve(child, state.coeff) if state.pick_op
224-
else Transpose(Solve(Transpose(child), Transpose(state.coeff))))
292+
if expr.diagonal:
293+
# Don't optimise further so that the translation to gem at a later can just spill ]1/a_ii[
294+
return expr * state.coeff if state.pick_op else state.coeff * expr
295+
else:
296+
return (Solve(child, state.coeff) if state.pick_op
297+
else Transpose(Solve(Transpose(child), Transpose(state.coeff))))
225298

226299

227300
@_push_mul.register(Transpose)

firedrake/slate/slac/tsfc_driver.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,8 @@ def compile_terminal_form(tensor, prefix, *, tsfc_parameters=None, coffee=True):
6060
kernels = tsfc_compile(form,
6161
subkernel_prefix,
6262
parameters=tsfc_parameters,
63-
coffee=coffee, split=False)
63+
coffee=coffee, split=False, diagonal=tensor.diagonal)
64+
6465
if kernels:
6566
cxt_k = ContextKernel(tensor=tensor,
6667
coefficients=form.coefficients(),

firedrake/slate/slac/utils.py

Lines changed: 34 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from ufl.algorithms.multifunction import MultiFunction
77

88
from gem import (Literal, Sum, Product, Indexed, ComponentTensor, IndexSum,
9-
Solve, Inverse, Variable, view)
9+
Solve, Inverse, Variable, view, Delta, Index, Division)
1010
from gem import indices as make_indices
1111
from gem.node import Memoizer
1212
from gem.node import pre_traversal as traverse_dags
@@ -148,15 +148,15 @@ def visit_Symbol(self, o, *args, **kwargs):
148148
return SymbolWithFuncallIndexing(o.symbol, o.rank, o.offset)
149149

150150

151-
def slate_to_gem(expression):
151+
def slate_to_gem(expression, options):
152152
"""Convert a slate expression to gem.
153153
154154
:arg expression: A slate expression.
155155
:returns: A singleton list of gem expressions and a mapping from
156156
gem variables to UFL "terminal" forms.
157157
"""
158158

159-
mapper, var2terminal = slate2gem(expression)
159+
mapper, var2terminal = slate2gem(expression, options)
160160
return mapper, var2terminal
161161

162162

@@ -186,9 +186,37 @@ def _slate2gem_block(expr, self):
186186
return view(child, *(slice(idx, idx+extent) for idx, extent in zip(offsets, expr.shape)))
187187

188188

189+
@_slate2gem.register(sl.DiagonalTensor)
190+
def _slate2gem_diagonal(expr, self):
191+
if not self.matfree:
192+
A, = map(self, expr.children)
193+
assert A.shape[0] == A.shape[1]
194+
i, j = (Index(extent=s) for s in A.shape)
195+
return ComponentTensor(Product(Indexed(A, (i, i)), Delta(i, j)), (i, j))
196+
else:
197+
raise NotImplementedError("Diagonals on Slate expressions are \
198+
not implemented in a matrix-free manner yet.")
199+
200+
189201
@_slate2gem.register(sl.Inverse)
190202
def _slate2gem_inverse(expr, self):
191-
return Inverse(*map(self, expr.children))
203+
tensor, = expr.children
204+
if expr.diagonal:
205+
# optimise inverse on diagonal tensor by translating to
206+
# matrix which contains the reciprocal values of the diagonal tensor
207+
A, = map(self, expr.children)
208+
i, j = (Index(extent=s) for s in A.shape)
209+
return ComponentTensor(Product(Division(Literal(1), Indexed(A, (i, i))),
210+
Delta(i, j)), (i, j))
211+
else:
212+
return Inverse(self(tensor))
213+
214+
215+
@_slate2gem.register(sl.Reciprocal)
216+
def _slate2gem_reciprocal(expr, self):
217+
child, = map(self, expr.children)
218+
indices = tuple(make_indices(len(child.shape)))
219+
return ComponentTensor(Division(Literal(1.), Indexed(child, indices)), indices)
192220

193221

194222
@_slate2gem.register(sl.Solve)
@@ -237,9 +265,10 @@ def _slate2gem_factorization(expr, self):
237265
return A
238266

239267

240-
def slate2gem(expression):
268+
def slate2gem(expression, options):
241269
mapper = Memoizer(_slate2gem)
242270
mapper.var2terminal = OrderedDict()
271+
mapper.matfree = options["replace_mul"]
243272
return mapper(expression), mapper.var2terminal
244273

245274

0 commit comments

Comments
 (0)