@@ -27,9 +27,12 @@ def optimise(expression, parameters):
2727
2828 Returns: An optimised Slate expression
2929 """
30- # 1 ) Block optimisation
30+ # 0 ) Block optimisation
3131 expression = push_block (expression )
3232
33+ # 1) DiagonalTensor optimisation
34+ expression = push_diag (expression )
35+
3336 # 2) Multiplication optimisation
3437 if expression .rank < 2 :
3538 expression = push_mul (expression , parameters )
@@ -113,6 +116,59 @@ def _push_block_block(expr, self, indices):
113116 return block
114117
115118
119+ def push_diag (expression ):
120+ """Executes a Slate compiler optimisation pass.
121+ The optimisation is achieved by pushing DiagonalTensor from the outside to the inside of an expression.
122+
123+ :arg expression: A (potentially unoptimised) Slate expression.
124+
125+ Returns: An optimised Slate expression, where DiagonalTensors are sitting
126+ on terminal tensors whereever possible.
127+ """
128+ mapper = MemoizerArg (_push_diag )
129+ return mapper (expression , False )
130+
131+
132+ @singledispatch
133+ def _push_diag (expr , self , diag ):
134+ raise AssertionError ("Cannot handle terminal type: %s" % type (expr ))
135+
136+
137+ @_push_diag .register (Transpose )
138+ @_push_diag .register (Add )
139+ @_push_diag .register (Negative )
140+ def _push_diag_distributive (expr , self , diag ):
141+ """Distributes the DiagonalTensors into these nodes"""
142+ return type (expr )(* map (self , expr .children , repeat (diag )))
143+
144+
145+ @_push_diag .register (Factorization )
146+ @_push_diag .register (Inverse )
147+ @_push_diag .register (Solve )
148+ @_push_diag .register (Mul )
149+ @_push_diag .register (Tensor )
150+ def _push_diag_stop (expr , self , diag ):
151+ """Diagonal Tensors cannot be pushed further into this set of nodes."""
152+ expr = type (expr )(* map (self , expr .children , repeat (False ))) if not expr .terminal else expr
153+ return DiagonalTensor (expr ) if diag else expr
154+
155+
156+ @_push_diag .register (AssembledVector )
157+ @_push_diag .register (Reciprocal )
158+ def _push_diag_vectors (expr , self , diag ):
159+ """DiagonalTensors should not be pushed onto rank-1 tensors."""
160+ if diag :
161+ raise AssertionError ("It is not legal to define DiagonalTensors on rank-1 tensors." )
162+ else :
163+ return expr
164+
165+
166+ @_push_diag .register (DiagonalTensor )
167+ def _push_diag_diag (expr , self , diag ):
168+ """DiagonalTensors are either pushed down or ignored when wrapped into another DiagonalTensor."""
169+ return self (* expr .children , not diag )
170+
171+
116172def push_mul (tensor , options ):
117173 """Executes a Slate compiler optimisation pass.
118174 The optimisation is achieved by pushing coefficients from
0 commit comments