Leo/nested interpolate#5097
Conversation
| lowered_operand = NestedInterpolateLowerer()(operand) | ||
| if lowered_operand is not operand: | ||
| return get_interpolator(self._ufl_expr_reconstruct_(lowered_operand)) |
There was a problem hiding this comment.
Wouldn't this approach freeze the nested expression? What if we change the numeric values between different calls to assemble on the same expression?
| @DAGTraverser.postorder | ||
| def _(self, o: UFLInterpolate, operand: Expr) -> Expr: | ||
| from firedrake.assemble import assemble | ||
| return as_ufl(assemble(o._ufl_expr_reconstruct_(operand))) |
There was a problem hiding this comment.
Isn't BaseFormAssembler the intended way of traversing the DAG for assemble?
| @process.register(UFLInterpolate) | ||
| @DAGTraverser.postorder | ||
| def _(self, o: UFLInterpolate, operand: Expr) -> Expr: | ||
| from firedrake.assemble import assemble | ||
| return as_ufl(assemble(o._ufl_expr_reconstruct_(operand))) |
There was a problem hiding this comment.
Instead of eagerly assembling each nested Interpolate node, I suggest we replace it with a placeholder Function and have the DAG traverser built a mapping {Function: Interpolator}. This way we delay populating the Functions with values until the outer interpolation has to be assembled . This ensures that the parloops corresponding to the inner interpolations always get on the stack and each is executed before the immediate outer interpolation.
| @process.register(UFLInterpolate) | |
| @DAGTraverser.postorder | |
| def _(self, o: UFLInterpolate, operand: Expr) -> Expr: | |
| from firedrake.assemble import assemble | |
| return as_ufl(assemble(o._ufl_expr_reconstruct_(operand))) | |
| @process.register(UFLInterpolate) | |
| @DAGTraverser.postorder | |
| def _(self, o: UFLInterpolate, operand: Expr) -> Expr: | |
| inner_node = o._ufl_expr_reconstruct_(operand) | |
| fn = Function(o.ufl_function_space()) | |
| self.subs[fn] = get_interpolator(inner_node) | |
| return as_ufl(fn) |
There was a problem hiding this comment.
This mapping could then be used in a dedicated Interpolator sub-class that explicitly chains the composition of callables:
class CompositeInterpolator(Interpolator):
"""
An interpolator for expressions containing nested interpolations
(possibly defined across different meshes).
"""
def __init__(self, outer_expr, subs):
super().__init__(outer_expr)
self.subs = subs # {Function: Interpolator} returned by NestedInterpolateLowerer
self._outer = get_interpolator(outer_expr)
def _get_callable(self, tensor=None, bcs=None, **kwargs):
inner_callables = [
interp._get_callable(tensor=fn)
for fn, interp in self.subs.items()
]
outer_callable = self._outer._get_callable(tensor=tensor, bcs=bcs)
def callable():
for c in inner_callables:
c()
return outer_callable()
return callableWith this we don't need to go through the entire assemble dispatch every single time. Instead, simply executing the callable returned by CompositeInterpolator suffices to ensure everything that's nested gets re-evaluated properly.
There was a problem hiding this comment.
I still think this logic should all go in assemble, as it handles compositions of BaseForm more generically.
The current issue is that symbolic Interpolate objects get reconstructed in BaseFormAssembler as we processes them, but the resulting numerical Interpolator does not get cached on the original expression, but on the processed one, which is then thrown away.
See this PR where we added caching for the Interpolator #4827
There was a problem hiding this comment.
But wouldn't assemble preprocess the expression every time it is called? Wouldn't the symbolic processing introduce overhead as opposed to targeting the callables that handle the execution directly?
There was a problem hiding this comment.
I still think this logic should all go in assemble, as it handles compositions of BaseForm more generically.
I agree with this.
But wouldn't assemble preprocess the expression every time it is called?
I don't know. But if so we should look to cache it. Why not cache the preprocessed expression on the input expression? Then you can cache interpolators on the preprocessed expression that end up being persistent.
There was a problem hiding this comment.
BaseFormAssembler already deals with both primal Expr/BaseForm, and dual Form/BaseForm nodes. The current BaseFormAssembler implementation could be improved if we turn it into a DAGTraverser that we could dispatch on any ufl type, but that's a separate issue.
There was a problem hiding this comment.
Is it necessarily a separate issue? This discussion makes it seem a little like infrastructure surgery is required in order to enable us to handle nested interpolates without tremendous hackery.
There was a problem hiding this comment.
but the resulting numerical
Interpolatordoes not get cached on the original expression, but on the processed one, which is then thrown away.
I am not sure I understand here. Is the processed expression thrown away while the original one gets retained? If so, then making the processed expression persistent would solve the issue as Connor suggests.
Is there a reason why the processed expression gets thrown away?
There was a problem hiding this comment.
assemble will just return the numerical Function/Cofunction. Any intermidiate symbolic expression used to arrive to the numerical result will not be returned by assemble. The right thing would be to cache them, or cache the assembler/interpolator on the original symbolic expression.
There was a problem hiding this comment.
Is it necessarily a separate issue?
The bug can be fixed without restructuring the entire class. And the class restructuring might be done in a way that preserves the bug. The issues are related, but can be dealt with separately.
| # Check for nested Interpolates first | ||
| lowered_operand = NestedInterpolateLowerer()(operand) | ||
| if lowered_operand is not operand: | ||
| return get_interpolator(self._ufl_expr_reconstruct_(lowered_operand)) |
There was a problem hiding this comment.
With my earlier suggestion on using CompositeInterpolator:
| # Check for nested Interpolates first | |
| lowered_operand = NestedInterpolateLowerer()(operand) | |
| if lowered_operand is not operand: | |
| return get_interpolator(self._ufl_expr_reconstruct_(lowered_operand)) | |
| # Check for nested Interpolates first | |
| lowerer = NestedInterpolateLowerer() | |
| lowered_operand = lowerer(operand) | |
| if lowered_operand is not operand: | |
| return CompositeInterpolator(self._ufl_expr_reconstruct_(lowered_operand), lowerer.subs) |
| assert mat_equals(res1, res3) | ||
|
|
||
|
|
||
| def test_nested_interpolate_expr_vom(): |
There was a problem hiding this comment.
Can you reproduce the bug with nested interpolate objects on a single mesh?
There was a problem hiding this comment.
It fails when compiling the dual evaluation kernel:
File "/Users/lac224/Coding/work/firedrake-dev/firedrake/tsfc/driver.py", line 346, in compile_expression_dual_evaluation
evaluation, basis_indices = to_element.dual_evaluation(fn, coordinate_mapping)
~~~~~~~~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/lac224/Coding/work/firedrake-dev/fiat/finat/tensorfiniteelement.py", line 178, in dual_evaluation
expr = fn(x)
File "/Users/lac224/Coding/work/firedrake-dev/firedrake/tsfc/driver.py", line 437, in __call__
gem_expr, = fem.compile_ufl(self.expression, translation_context, point_sum=False)
~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/lac224/Coding/work/firedrake-dev/firedrake/tsfc/fem.py", line 854, in compile_ufl
result = map_expr_dags(context.translator, expressions)
File "/Users/lac224/Coding/work/firedrake-dev/ufl/ufl/corealg/map_dag.py", line 114, in map_expr_dags
r = handlers[v._ufl_typecode_](v, *(vcache[u] for u in v.ufl_operands))
File "/Users/lac224/Coding/work/firedrake-dev/ufl/ufl/corealg/multifunction.py", line 99, in undefined
raise ValueError(f"No handler defined for {o._ufl_class_.__name__}.")
ValueError: No handler defined for Interpolate.
There was a problem hiding this comment.
This might indicate that BaseFormAssembler is not recurring on the operand before construction the Interpolator
There was a problem hiding this comment.
There's a case for adding Interpolate to the form compiler. We would be able to directly generate cell kernels to assemble an aij matrix from inner(Interpolate(grad(u), V2), ...)*dx. In theory we would require to combine compile_dual_evaluation with compile_form
No description provided.