**Steps 2-3**: Solve the wave equation and compute the functional.

We create a ``ReducedFunctional`` for each source, which for our case means one per ensemble member. By creating a ``ReducedFunctional`` per component that we are parallelising over (i.e. per source), rather than one per ensemble member, we can change the ensemble parallel partition with minimal changes to the code::
    from firedrake.adjoint import *
    my_ensemble)
    continue_annotation()
    J_val = 0.0
    with set_working_tape() as tape:
Why is with set_working_tape() as tape: better here?
It makes it really clear which bit of the code is being recorded on that tape, and it makes sure that the ReducedFunctional for each
dham left a comment
Needs a manual section in the ensemble parallelism chapter to explain the overall mathematical model for ensemble parallel.
Split the classes for the two cases and introduce them in a way that avoids an immediate backward-incompatible change.
from firedrake.adjoint_utils.checkpointing import disk_checkpointing
class EnsembleAdjVec(OverloadedType):
Why not EnsembleAdjFloat?
    # Adjoint action is a reduction so we just piggyback.
    # Possibly don't do this if we're being created by the
    # reduction rf to avoid infinite recursion.
    if not _only_forward:
Could this be avoided if you make self._reduce into a cached_property and hence lazily evaluated?
    return self.derivative(hessian_input, apply_riesz=apply_riesz)
class EnsembleTransformReducedFunctional(AbstractReducedFunctional):
This name is a bit confusing to me as it makes it seem very generic. It could be equivalent to EnsembleReducedFunctional or it could be an ABC for EnsembleAllgatherReducedFunctional etc.
Maybe something like EnsembleNoReduceReducedFunctional (lol)? EnsemblePipelineReducedFunctional? EnsemblePassthroughReducedFunctional?
Previous implementation
The `EnsembleReducedFunctional` implements a functional with many independent terms all depending on the same control, which are calculated in parallel over the ensemble:

$$J(m) = \sum_i J_i(m)$$

This is three operations composed together: broadcast the control $m$ to each ensemble member (Bcast), evaluate each $J_i$ locally (Transform), and sum the results (Reduce).

It also had limited support for using distributed controls, i.e. a different $m_i$ for each $J_i$:

$$J(m_1, \dots, m_n) = \sum_i J_i(m_i)$$

This is two operations composed together: Transform and Reduce (no broadcast is needed, since each member owns its own control).
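As a rough illustration of this mathematical model, here is a plain-Python sketch of the composed operations. No Firedrake or MPI is used; `bcast`, `transform` and `reduce_sum` are stand-in names for the operations described above, not the real classes, and the quadratic terms are invented for the example.

```python
# Plain-Python stand-ins for the three ensemble operations. In Firedrake
# these act on EnsembleFunction/AdjFloat over an Ensemble communicator;
# here they are ordinary lists and floats.

def bcast(m, n):
    """Copy the single control m to every ensemble member."""
    return [m] * n

def transform(ms, local_terms):
    """Evaluate each independent term J_i on its own control m_i."""
    return [J_i(m_i) for J_i, m_i in zip(local_terms, ms)]

def reduce_sum(Js):
    """Sum the per-member values into the global functional."""
    return sum(Js)

# Hypothetical local terms J_i(m) = a_i * m**2 (a=a pins each coefficient).
terms = [lambda m, a=a: a * m ** 2 for a in (1, 2, 3)]

def J_single(m):
    # Single shared control: Bcast -> Transform -> Reduce.
    return reduce_sum(transform(bcast(m, len(terms)), terms))

def J_distributed(ms):
    # Distributed controls: only Transform -> Reduce.
    return reduce_sum(transform(ms, terms))

print(J_single(2.0))                   # (1 + 2 + 3) * 2**2 = 24.0
print(J_distributed([1.0, 2.0, 3.0]))  # 1*1 + 2*4 + 3*9 = 36.0
```

Note that the distributed-control variant simply drops the Bcast stage, which is the structural difference between the two cases described above.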
Issues with previous implementation
In the distributed controls case the `Control`s were `Function`s on each spatial comm and were not collective over the global comm. This breaks the `ReducedFunctional` contract and meant that this version would fail the Taylor test (e.g. the xfailed tests in `firedrake/tests/firedrake/adjoint/test_ensemble_reduced_functional.py`, lines 99 to 101 in a101140).

In the case with a single control $m$, the user had to implement the local part of the sum (i.e. manually sum the $J_i$ on the local rank themselves).
What does this PR do?
This PR splits the `EnsembleReducedFunctional` into separate `ReducedFunctional` classes for the Bcast, Transform, and Reduce operations. `EnsembleReducedFunctional` is then re-implemented in terms of these operations.

These all use `EnsembleFunction` as the control and/or functional as appropriate, so we get collective behaviour and the Taylor tests for a distributed control pass (and we can use them in optimisers).

It also implements an additional `OverloadedType` called `EnsembleAdjVec`, which is a distributed vector of `AdjFloat`. It is to `AdjFloat` what `EnsembleFunction` is to `Function`.
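One way to see why a distributed vector of scalar adjoint values is needed: the adjoint of a broadcast is a sum-reduction, and the adjoint of a sum-reduction is a broadcast, so the reverse sweep through the Bcast/Reduce stages carries one scalar adjoint per ensemble member. A pure-Python sketch (the function names here are illustrative stand-ins, not the Firedrake API) checking the adjoint identity $\langle Av, w\rangle = \langle v, A^*w\rangle$ for the broadcast:

```python
# Forward broadcast and its adjoint (a sum-reduction), and forward
# sum-reduction and its adjoint (a broadcast). Pure-Python stand-ins.

def bcast_fwd(m, n):
    """Copy one control value to n ensemble members."""
    return [m] * n

def bcast_adj(vbars):
    """Adjoint of 'copy to n members' is 'sum the n adjoint values'."""
    return sum(vbars)

def reduce_fwd(vs):
    """Sum n per-member values into one."""
    return sum(vs)

def reduce_adj(jbar, n):
    """Adjoint of 'sum n values' is 'copy the adjoint value to n members'."""
    return [jbar] * n

def dot(xs, ys):
    return sum(x * y for x, y in zip(xs, ys))

n = 4
m, vbars = 3.0, [1.0, 2.0, 3.0, 4.0]
# Adjoint identity for Bcast: <Bcast(m), vbar> == m * Bcast_adj(vbar)
lhs = dot(bcast_fwd(m, n), vbars)
rhs = m * bcast_adj(vbars)
print(lhs == rhs)  # True
```

This is also the sense in which the diff comment above says the adjoint action "is a reduction so we just piggyback".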
To use distributed controls a user will need to create an `EnsembleFunctionSpace` in order to use an `EnsembleFunction` as the control. But in return you get the collective behaviour.

A significant API change for both individual and collective controls is that the $J_i$ on each spatial comm are now each passed as a separate `ReducedFunctional`, rather than taping all the $J_i$ and the local reduction on one tape.
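The per-term pattern can be mimicked in plain Python. The class names `MiniRF` and `EnsembleSumRF` below are hypothetical stand-ins, not the new Firedrake API: each $J_i$ becomes its own reduced functional, and the ensemble-level object owns the reduction, including in the derivative, so the user never sums terms by hand.

```python
class MiniRF:
    """Hypothetical stand-in for one per-term ReducedFunctional,
    with a hand-coded derivative for the example."""

    def __init__(self, f, df):
        self.f, self.df = f, df

    def __call__(self, m):
        return self.f(m)

    def derivative(self, m):
        return self.df(m)


class EnsembleSumRF:
    """Hypothetical stand-in for the composed Bcast -> Transform -> Reduce
    functional built from one ReducedFunctional per term."""

    def __init__(self, rfs):
        self.rfs = rfs

    def __call__(self, m):
        # Bcast m to every term, evaluate, then reduce by summation.
        return sum(rf(m) for rf in self.rfs)

    def derivative(self, m):
        # Adjoint of Bcast is a sum, so the per-term gradients are reduced.
        return sum(rf.derivative(m) for rf in self.rfs)


# Example terms J_i(m) = a_i * m**2 with dJ_i/dm = 2 * a_i * m.
rfs = [MiniRF(lambda m, a=a: a * m ** 2, lambda m, a=a: 2 * a * m)
       for a in (1, 2, 3)]
J = EnsembleSumRF(rfs)
print(J(2.0), J.derivative(2.0))  # 24.0 24.0
```

In this sketch the local reduction lives in `EnsembleSumRF` rather than on the user's tape, mirroring the API change described above.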