@@ -27,9 +27,12 @@ def optimise(expression, parameters):
2727
2828 Returns: An optimised Slate expression
2929 """
30- # 1 ) Block optimisation
30+ # 0 ) Block optimisation
3131 expression = push_block (expression )
3232
33+ # 1) DiagonalTensor optimisation
34+ expression = push_diag (expression )
35+
3336 # 2) Multiplication optimisation
3437 if expression .rank < 2 :
3538 expression = push_mul (expression , parameters )
@@ -70,6 +73,8 @@ def _push_block_transpose(expr, self, indices):
7073
7174@_push_block .register (Add )
7275@_push_block .register (Negative )
76+ @_push_block .register (DiagonalTensor )
77+ @_push_block .register (Reciprocal )
7378def _push_block_distributive (expr , self , indices ):
7479 """Distributes Blocks for these nodes"""
7580 return type (expr )(* map (self , expr .children , repeat (indices ))) if indices else expr
@@ -111,6 +116,66 @@ def _push_block_block(expr, self, indices):
111116 return block
112117
113118
119+ def push_diag (expression ):
120+ """Executes a Slate compiler optimisation pass.
121+ The optimisation is achieved by pushing DiagonalTensor from the outside to the inside of an expression.
122+
123+ :arg expression: A (potentially unoptimised) Slate expression.
124+
125+ Returns: An optimised Slate expression, where DiagonalTensors are sitting
126+ on terminal tensors whereever possible.
127+ """
128+ mapper = MemoizerArg (_push_diag )
129+ return mapper (expression , False )
130+
131+
132+ @singledispatch
133+ def _push_diag (expr , self , diag ):
134+ raise AssertionError ("Cannot handle terminal type: %s" % type (expr ))
135+
136+
137+ @_push_diag .register (Transpose )
138+ @_push_diag .register (Add )
139+ @_push_diag .register (Negative )
140+ def _push_diag_distributive (expr , self , diag ):
141+ """Distributes the DiagonalTensors into these nodes"""
142+ return type (expr )(* map (self , expr .children , repeat (diag )))
143+
144+
145+ @_push_diag .register (Factorization )
146+ @_push_diag .register (Inverse )
147+ @_push_diag .register (Solve )
148+ @_push_diag .register (Mul )
149+ @_push_diag .register (Tensor )
150+ def _push_diag_stop (expr , self , diag ):
151+ """Diagonal Tensors cannot be pushed further into this set of nodes."""
152+ expr = type (expr )(* map (self , expr .children , repeat (False ))) if not expr .terminal else expr
153+ return DiagonalTensor (expr ) if diag else expr
154+
155+
156+ @_push_diag .register (Block )
157+ def _push_diag_block (expr , self , diag ):
158+ """Diagonal Tensors cannot be pushed further into this set of nodes."""
159+ expr = type (expr )(* map (self , expr .children , repeat (False )), expr ._indices ) if not expr .terminal else expr
160+ return DiagonalTensor (expr ) if diag else expr
161+
162+
163+ @_push_diag .register (AssembledVector )
164+ @_push_diag .register (Reciprocal )
165+ def _push_diag_vectors (expr , self , diag ):
166+ """DiagonalTensors should not be pushed onto rank-1 tensors."""
167+ if diag :
168+ raise AssertionError ("It is not legal to define DiagonalTensors on rank-1 tensors." )
169+ else :
170+ return expr
171+
172+
173+ @_push_diag .register (DiagonalTensor )
174+ def _push_diag_diag (expr , self , diag ):
175+ """DiagonalTensors are either pushed down or ignored when wrapped into another DiagonalTensor."""
176+ return self (* expr .children , not diag )
177+
178+
114179def push_mul (tensor , options ):
115180 """Executes a Slate compiler optimisation pass.
116181 The optimisation is achieved by pushing coefficients from
@@ -179,6 +244,8 @@ def _drop_double_transpose_transpose(expr, self):
179244@_drop_double_transpose .register (Mul )
180245@_drop_double_transpose .register (Solve )
181246@_drop_double_transpose .register (Inverse )
247+ @_drop_double_transpose .register (DiagonalTensor )
248+ @_drop_double_transpose .register (Reciprocal )
182249def _drop_double_transpose_distributive (expr , self ):
183250 """Distribute into the children of the expression. """
184251 return type (expr )(* map (self , expr .children ))
@@ -202,6 +269,8 @@ def _push_mul_tensor(expr, self, state):
202269
203270
204271@_push_mul .register (AssembledVector )
272+ @_push_mul .register (DiagonalTensor )
273+ @_push_mul .register (Reciprocal )
205274def _push_mul_vector (expr , self , state ):
206275 """Do not push into AssembledVectors."""
207276 return expr
@@ -220,8 +289,12 @@ def _push_mul_inverse(expr, self, state):
220289 with a coefficient into a Solve via A.inv*b = A.solve(b)
221290 or b*A^{-1}= (A.T.inv*b.T).T = A.T.solve(b.T).T ."""
222291 child , = expr .children
223- return (Solve (child , state .coeff ) if state .pick_op
224- else Transpose (Solve (Transpose (child ), Transpose (state .coeff ))))
292+ if expr .diagonal :
293+ # Don't optimise further so that the translation to gem at a later can just spill ]1/a_ii[
294+ return expr * state .coeff if state .pick_op else state .coeff * expr
295+ else :
296+ return (Solve (child , state .coeff ) if state .pick_op
297+ else Transpose (Solve (Transpose (child ), Transpose (state .coeff ))))
225298
226299
227300@_push_mul .register (Transpose )
0 commit comments