
Refactor EnsembleReducedFunctional #4965

Open
JHopeCollins wants to merge 4 commits into main from JHopeCollins/ensemble-rf-refactor

Conversation

@JHopeCollins (Member) commented Mar 12, 2026

Previous implementation

The EnsembleReducedFunctional implements a functional as a sum of many independent terms, all depending on the same control, which are calculated in parallel over the ensemble:

$$J(m) = \sum_i J_i(m)$$

This is three operations composed together:

  1. Broadcast $m$ to all $J_i$
  2. Transform all $J_i(m)$
  3. Reduce $J=\sum J_i$
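The three composed operations can be sketched in plain Python (a minimal illustration, not the Firedrake implementation: the ensemble is simulated by a list of per-member functionals, so there is no real MPI communication, and all names are invented for this sketch):

```python
# Illustrative sketch of the broadcast -> transform -> reduce composition
# for a single shared control m. In Firedrake each step would be
# collective over the ensemble communicator; here the "ensemble" is a list.

def bcast(m, n_members):
    """Step 1: broadcast the single control m to every ensemble member."""
    return [m] * n_members

def transform(local_functionals, controls):
    """Step 2: evaluate each J_i on its copy of the control."""
    return [J_i(m_i) for J_i, m_i in zip(local_functionals, controls)]

def reduce_sum(values):
    """Step 3: reduce the local values to the global functional J."""
    return sum(values)

# Example terms: J_i(m) = (i + 1) * m**2, so J(m) = 6 * m**2 for 3 members.
Js = [lambda m, i=i: (i + 1) * m**2 for i in range(3)]
m = 2.0
J = reduce_sum(transform(Js, bcast(m, len(Js))))
```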

It also had limited support for using distributed controls, i.e. a different $m_i$ for each $J_i$:

$$J(m) = \sum_i J_i(m_i)$$

This is two operations composed together:

  1. Transform all $J_i(m_i)$
  2. Reduce $J=\sum J_i$
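In the distributed-controls case the broadcast step disappears, since each member already owns its own control. A matching sketch (again a pure-Python analogue with invented names, not the Firedrake API):

```python
# Illustrative sketch of the transform -> reduce composition for
# distributed controls: each ensemble member has its own control m_i,
# so no broadcast is needed.

def transform(local_functionals, local_controls):
    """Step 1: evaluate each J_i(m_i) on its own control."""
    return [J_i(m_i) for J_i, m_i in zip(local_functionals, local_controls)]

def reduce_sum(values):
    """Step 2: reduce the local values to the global functional J."""
    return sum(values)

Js = [lambda m: m**2, lambda m: 2 * m, lambda m: m + 1]
ms = [3.0, 4.0, 5.0]   # one control per ensemble member
J = reduce_sum(transform(Js, ms))   # 9 + 8 + 6
```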

Issues with previous implementation

In the distributed controls case the Controls were Functions on each spatial comm and were not collective over the global comm. This broke the ReducedFunctional contract and meant that this version would fail the Taylor test (e.g. the xfailed tests here):

@pytest.mark.xfail(reason="Taylor's test fails because the inner product \
between the perturbation and gradient is not allreduced \
for `scatter_control=False`.")

In the case with a single control $m$, the user had to implement the local part of the sum (i.e. manually sum the $J_i$ on the local rank themselves).
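The missing allreduce matters because the Taylor test checks the directional derivative, which for distributed controls is the *global* inner product $\langle dJ/dm, h\rangle = \sum_i \langle dJ_i/dm_i, h_i\rangle$; any single rank's local dot product is only one term of that sum. A toy illustration (ranks simulated by list entries; no real MPI):

```python
# Illustrative sketch of why the missing allreduce broke the Taylor test.
# Each "rank" holds a local gradient and perturbation; the true
# directional derivative is the sum of the local inner products, which
# is exactly what an allreduce over the ensemble comm would compute.

def local_dot(grad_i, pert_i):
    return sum(g * p for g, p in zip(grad_i, pert_i))

# Per-"rank" gradients and perturbations (simulated ensemble of 2 members).
grads = [[1.0, 2.0], [3.0, 4.0]]
perts = [[0.1, 0.1], [0.1, 0.1]]

local_terms = [local_dot(g, p) for g, p in zip(grads, perts)]
global_inner = sum(local_terms)

# Each rank's local term alone underestimates the true derivative,
# so a Taylor test using only the local inner product converges wrongly.
assert all(t < global_inner for t in local_terms)
```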

What does this PR do?

This PR splits the EnsembleReducedFunctional into separate ReducedFunctional classes for the Bcast, Transform, and Reduce. EnsembleReducedFunctional is then re-implemented in terms of these operations.

These all use EnsembleFunction as the control and/or the functional as appropriate, so we get collective behaviour and the Taylor tests for a distributed control pass (and we can use them in optimisers).

It also implements an additional OverloadedType called EnsembleAdjVec which is a distributed vector of AdjFloat. It is to AdjFloat what EnsembleFunction is to Function.
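To make the analogy concrete, here is a toy pure-Python analogue of such a distributed vector of scalars (this is an invented illustration, not the EnsembleAdjVec implementation: the class name and methods are hypothetical, and the comment marks where the real class would be collective):

```python
# Toy analogue of a distributed vector of scalar adjoint values:
# one float per ensemble member, with reductions that must be collective.

class ToyEnsembleAdjVec:
    def __init__(self, local_values):
        self.local_values = list(local_values)

    def __add__(self, other):
        # Elementwise addition needs no communication.
        return ToyEnsembleAdjVec(
            a + b for a, b in zip(self.local_values, other.local_values))

    def dot(self, other):
        # In a real distributed implementation this sum would be
        # allreduced over the ensemble communicator.
        return sum(a * b
                   for a, b in zip(self.local_values, other.local_values))

u = ToyEnsembleAdjVec([1.0, 2.0, 3.0])
v = ToyEnsembleAdjVec([1.0, 1.0, 1.0])
```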

What changes will users see?

To use distributed controls a user will need to create an EnsembleFunctionSpace so that they can use an EnsembleFunction as the control, but in return they get the collective behaviour.

A significant API change for both individual and collective controls is that now the $J_i$ on each spatial comm are passed as a separate ReducedFunctional, rather than taping all the $J_i$ and the local reduction on one tape.

  • This means that the user only has to implement $J_i$ rather than any of the reduction.
  • It ensures that each $J_i$ has a different tape, so can be run in any order without polluting state for the derivative calculation.
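The new shape of the API can be sketched with plain Python callables (a minimal analogue under invented names, not the Firedrake signatures): the user supplies one reduced functional per term, each conceptually recorded on its own tape, and the ensemble object owns the whole reduction.

```python
# Pure-Python analogue of the new API shape: per-term functionals are
# supplied separately and the ensemble class performs the sum, so the
# user never writes the reduction themselves.

class LocalRF:
    """Stands in for a per-term ReducedFunctional J_i with its own tape."""
    def __init__(self, func, deriv):
        self.func = func
        self.deriv = deriv
    def __call__(self, m):
        return self.func(m)
    def derivative(self, m):
        return self.deriv(m)

class EnsembleSumRF:
    """Composes the per-term functionals: J(m) = sum_i J_i(m)."""
    def __init__(self, local_rfs):
        self.local_rfs = local_rfs
    def __call__(self, m):
        # transform then reduce, owned by the ensemble object
        return sum(rf(m) for rf in self.local_rfs)
    def derivative(self, m):
        # derivatives of a sum add, term by term
        return sum(rf.derivative(m) for rf in self.local_rfs)

rfs = [LocalRF(lambda m: m**2, lambda m: 2 * m),
       LocalRF(lambda m: 3 * m, lambda m: 3.0)]
J = EnsembleSumRF(rfs)
# J(2) = 4 + 6, dJ/dm(2) = 4 + 3
```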

Comment on lines +264 to +271
**Steps 2-3**: Solve the wave equation and compute the functional.
We create a ``ReducedFunctional`` for each source, which in our
case means one per ensemble member. By creating a ``ReducedFunctional``
per component that we are parallelising over (i.e. per source),
rather than explicitly one per ensemble member, we can change
the ensemble parallel partition with minimal changes to the code.::

from firedrake.adjoint import *
Contributor

This is better 🙃

my_ensemble)
continue_annotation()
J_val = 0.0
with set_working_tape() as tape:
Contributor

Why is ``with set_working_tape() as tape:`` better here?

Member Author

It makes it really clear which bit of the code is being recorded on that tape, and it makes sure that the ReducedFunctional for each $J_i$ has a different tape.

@JHopeCollins JHopeCollins marked this pull request as ready for review March 12, 2026 16:01
Member
@dham dham left a comment


Needs a manual section in the ensemble parallelism chapter to explain the overall mathematical model for ensemble parallelism.

Split the classes for the two cases and introduce them in a way that avoids an immediate backward-incompatible change.

from firedrake.adjoint_utils.checkpointing import disk_checkpointing


class EnsembleAdjVec(OverloadedType):
Contributor

Why not EnsembleAdjFloat?

# Adjoint action is a reduction so we just piggyback.
# Possibly don't do this if we're being created by the
# reduction rf to avoid infinite recursion.
if not _only_forward:
Contributor

Could this be avoided if you make self._reduce into a cached_property and hence lazily evaluated?

return self.derivative(hessian_input, apply_riesz=apply_riesz)


class EnsembleTransformReducedFunctional(AbstractReducedFunctional):
Contributor

This name is a bit confusing to me as it makes it seem very generic. It could be equivalent to EnsembleReducedFunctional or it could be an ABC for EnsembleAllgatherReducedFunctional etc.

Maybe something like EnsembleNoReduceReducedFunctional (lol)? EnsemblePipelineReducedFunctional? EnsemblePassthroughReducedFunctional?

4 participants