Skip to content

Commit ef316ff

Browse files
committed
Slate optimiser: push DiagonalTensor as far inside an expression as possible
1 parent 6885823 commit ef316ff

2 files changed

Lines changed: 79 additions & 1 deletion

File tree

firedrake/slate/slac/optimise.py

Lines changed: 57 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,12 @@ def optimise(expression, parameters):
2727
2828
Returns: An optimised Slate expression
2929
"""
30-
# 1) Block optimisation
30+
# 0) Block optimisation
3131
expression = push_block(expression)
3232

33+
# 1) DiagonalTensor optimisation
34+
expression = push_diag(expression)
35+
3336
# 2) Multiplication optimisation
3437
if expression.rank < 2:
3538
expression = push_mul(expression, parameters)
@@ -113,6 +116,59 @@ def _push_block_block(expr, self, indices):
113116
return block
114117

115118

119+
def push_diag(expression):
120+
"""Executes a Slate compiler optimisation pass.
121+
The optimisation is achieved by pushing DiagonalTensor from the outside to the inside of an expression.
122+
123+
:arg expression: A (potentially unoptimised) Slate expression.
124+
125+
Returns: An optimised Slate expression, where DiagonalTensors are sitting
126+
on terminal tensors whereever possible.
127+
"""
128+
mapper = MemoizerArg(_push_diag)
129+
return mapper(expression, False)
130+
131+
132+
@singledispatch
133+
def _push_diag(expr, self, diag):
134+
raise AssertionError("Cannot handle terminal type: %s" % type(expr))
135+
136+
137+
@_push_diag.register(Transpose)
138+
@_push_diag.register(Add)
139+
@_push_diag.register(Negative)
140+
def _push_diag_distributive(expr, self, diag):
141+
"""Distributes the DiagonalTensors into these nodes"""
142+
return type(expr)(*map(self, expr.children, repeat(diag)))
143+
144+
145+
@_push_diag.register(Factorization)
146+
@_push_diag.register(Inverse)
147+
@_push_diag.register(Solve)
148+
@_push_diag.register(Mul)
149+
@_push_diag.register(Tensor)
150+
def _push_diag_stop(expr, self, diag):
151+
"""Diagonal Tensors cannot be pushed further into this set of nodes."""
152+
expr = type(expr)(*map(self, expr.children, repeat(False))) if not expr.terminal else expr
153+
return DiagonalTensor(expr) if diag else expr
154+
155+
156+
@_push_diag.register(AssembledVector)
157+
@_push_diag.register(Reciprocal)
158+
def _push_diag_vectors(expr, self, diag):
159+
"""DiagonalTensors should not be pushed onto rank-1 tensors."""
160+
if diag:
161+
raise AssertionError("It is not legal to define DiagonalTensors on rank-1 tensors.")
162+
else:
163+
return expr
164+
165+
166+
@_push_diag.register(DiagonalTensor)
167+
def _push_diag_diag(expr, self, diag):
168+
"""DiagonalTensors are either pushed down or ignored when wrapped into another DiagonalTensor."""
169+
return self(*expr.children, not diag)
170+
171+
116172
def push_mul(tensor, options):
117173
"""Executes a Slate compiler optimisation pass.
118174
The optimisation is achieved by pushing coefficients from

tests/slate/test_optimise.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,28 @@ def test_drop_transposes(TC_non_symm):
329329
compare_slate_tensors(expressions, opt_expressions)
330330

331331

332+
#######################################
333+
# Test diagonal optimisation pass
334+
#######################################
335+
def test_push_diagonal(TC_non_symm):
336+
"""Test Optimisers's ability to push DiagonalTensors inside expressions."""
337+
A, C = TC_non_symm
338+
339+
expressions = [DiagonalTensor(A), DiagonalTensor(A+A),
340+
DiagonalTensor(-A), DiagonalTensor(A*A),
341+
DiagonalTensor(A).inv]
342+
opt_expressions = [DiagonalTensor(A), DiagonalTensor(A)+DiagonalTensor(A),
343+
-DiagonalTensor(A), DiagonalTensor(A*A),
344+
DiagonalTensor(A).inv]
345+
compare_tensor_expressions(expressions)
346+
compare_slate_tensors(expressions, opt_expressions)
347+
348+
expressions = [DiagonalTensor(A+A).solve(C)]
349+
opt_expressions = [(DiagonalTensor(A)+DiagonalTensor(A)).solve(C)]
350+
compare_vector_expressions(expressions)
351+
compare_slate_tensors(expressions, opt_expressions)
352+
353+
332354
#######################################
333355
# Helper functions
334356
#######################################

0 commit comments

Comments
 (0)