Skip to content
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
1d830b8
implementation of multi-stage time integrators
fernanvr May 5, 2025
7f087b3
Merge remote-tracking branch 'upstream/main' into multi-stage-time-in…
fernanvr Jun 13, 2025
214d882
Return of first PR comments
fernanvr Jun 13, 2025
d6c4d4a
Return of first PR comments
fernanvr Jun 13, 2025
78f8a0b
2nd PR revision
fernanvr Jun 23, 2025
1c9d517
2nd PR revision
fernanvr Jun 23, 2025
11db48b
2rd PR, updating tests and suggestions of 2nd PR revision
fernanvr Jun 25, 2025
83dfb04
3rd PR, updating tests and suggestions of 2nd PR revision
fernanvr Jun 25, 2025
d47a106
4th PR revision, code refining and improving tests
fernanvr Jun 25, 2025
1f93a45
4th PR revision, code refining and improving tests
fernanvr Jun 25, 2025
eea3a52
5th PR revision, one suggestion from EdC and improving tests
fernanvr Jun 26, 2025
11d1429
including two more Runge-Kutta methods and improving tests: checking …
fernanvr Jul 1, 2025
4637ac2
changes to consider coupled Multistage equations
fernanvr Jul 16, 2025
ac1da7e
Improvements of the HORK_EXP
fernanvr Aug 15, 2025
e9b3533
Merge branch 'main' into multi-stage-time-integrator
fernanvr Aug 15, 2025
dc3dd77
Merge remote-tracking branch 'upstream/main' into multi-stage-time-in…
fernanvr Oct 8, 2025
1fd4a02
tuples, improved class names, extensive tests
fernanvr Oct 8, 2025
a0c45c1
improving spacing in some tests
fernanvr Oct 8, 2025
93c6e3f
Add MFE time stepping Jupyter notebook
fernanvr Oct 23, 2025
ef8d1ac
Remove MFE_time_size.ipynb notebook
fernanvr Oct 23, 2025
fa5acac
Update multistage implementation and tests
fernanvr Oct 29, 2025
143d0c2
Return of first PR comments
fernanvr Jun 13, 2025
5f67b91
updating small changes from EdC review on 26-03-2026
fernanvr Mar 26, 2026
552fd7f
Isolate multistage-related changes only
fernanvr Mar 27, 2026
e9d2000
merging two classes of Runge-Kutta
fernanvr Mar 27, 2026
f7c9ea3
Merge full multistage history while keeping clean tree
fernanvr Mar 27, 2026
cf1003c
fixing test_multistage file
fernanvr Apr 6, 2026
1fd480b
Remove devito/ir/equations/algorithms.py and devito/operator/operator…
fernanvr Apr 9, 2026
a875224
implemented suggestions of EdC and Fabio
fernanvr Apr 14, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 36 additions & 2 deletions devito/ir/equations/algorithms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,15 @@
from devito.tools import (Ordering, as_tuple, flatten, filter_sorted, filter_ordered,
frozendict)
from devito.types import (Dimension, Eq, IgnoreDimSort, SubDimension,
ConditionalDimension)
ConditionalDimension, MultiStage)
from devito.types.array import Array
from devito.types.basic import AbstractFunction
from devito.types.dimension import MultiSubDimension, Thickness
from devito.data.allocators import DataReference
from devito.logger import warning

__all__ = ['dimension_sort', 'lower_exprs', 'concretize_subdims']

__all__ = ['dimension_sort', 'lower_multistage', 'lower_exprs', 'concretize_subdims']


def dimension_sort(expr):
Expand Down Expand Up @@ -95,6 +96,39 @@ def handle_indexed(indexed):
return ordering


def lower_multistage(expressions, **kwargs):
"""
Separating the multi-stage time-integrator scheme in stages:
* If the object is MultiStage, it creates the stages of the method.
"""
return _lower_multistage(expressions, **kwargs)


@singledispatch
def _lower_multistage(expr, **kwargs):
"""
Default handler for expressions that are not MultiStage.
Simply return them in a list.
"""
return [expr]


@_lower_multistage.register(MultiStage)
def _(expr, **kwargs):
"""
Specialized handler for MultiStage expressions.
"""
return expr._evaluate(**kwargs)


@_lower_multistage.register(Iterable)
def _(exprs, **kwargs):
"""
Handle iterables of expressions.
"""
return sum([_lower_multistage(expr, **kwargs) for expr in exprs], [])


def lower_exprs(expressions, subs=None, **kwargs):
"""
Lowering an expression consists of the following passes:
Expand Down
9 changes: 7 additions & 2 deletions devito/operations/solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from devito.finite_differences.derivative import Derivative
from devito.tools import as_tuple

from devito.types.multistage import resolve_method

__all__ = ['solve', 'linsolve']


Expand Down Expand Up @@ -56,9 +58,12 @@ def solve(eq, target, **kwargs):

# We need to rebuild the vector/tensor as sympy.solve outputs a tuple of solutions
if len(sols) > 1:
return target.new_from_mat(sols)
sols_temp = target.new_from_mat(sols)
else:
return sols[0]
sols_temp = sols[0]

method = kwargs.get("method", None)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is method a string here, or is it the class for the method? In the latter case, it would remove the need to have the method_registry mapper. Furthermore, it would allow you to have method.resolve(target, sols_temp) here, which is tidier

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a string. The idea is that the user provides a string to identify which time integrator to apply.

return sols_temp if method is None else resolve_method(method)(target, sols_temp)


def linsolve(expr, target, **kwargs):
Expand Down
5 changes: 3 additions & 2 deletions devito/operator/operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
InvalidOperator)
from devito.logger import (debug, info, perf, warning, is_log_enabled_for,
switch_log_level)
from devito.ir.equations import LoweredEq, lower_exprs, concretize_subdims
from devito.ir.equations import LoweredEq, lower_multistage, lower_exprs, concretize_subdims
from devito.ir.clusters import ClusterGroup, clusterize
from devito.ir.iet import (Callable, CInterface, EntryFunction, FindSymbols,
MetaCall, derive_parameters, iet_build)
Expand All @@ -36,7 +36,6 @@
disk_layer)
from devito.types.dimension import Thickness


Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please run the linter (flake8) 🙂

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

__all__ = ['Operator']


Expand Down Expand Up @@ -327,6 +326,8 @@ def _lower_exprs(cls, expressions, **kwargs):
* Apply substitution rules;
* Shift indices for domain alignment.
"""
expressions = lower_multistage(expressions, **kwargs)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should at least be called after expand = ...

and, perhaps, benefit from a more generic name such as lower_timestepping

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we could also move it inside a _lower_dsl, which internally calls _specialize_dsl, just like we already do for expressions/clusters/stree/iet

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That makes sense, thanks! This was part of an earlier approach that after one meeting with Devito's team was decided to be left like a plan b, so it shouldn’t actually be here. I’ll remove it from the PR to appear only the actual approach—though I agree that structuring it that way would make sense if we revisit this idea in the future and I already changed accordingly.

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Quick correction to my previous comment: I realized this part is actually still in use in the current implementation. I’ve updated it taking your suggestions into account (ordering + naming), so it should now reflect what it was intended.


expand = kwargs['options'].get('expand', True)

# Specialization is performed on unevaluated expressions
Expand Down
2 changes: 2 additions & 0 deletions devito/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,5 @@
from .relational import * # noqa
from .sparse import * # noqa
from .tensor import * # noqa

from .multistage import * # noqa
179 changes: 179 additions & 0 deletions devito/types/multistage.py
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this file should be moved to somewhere like devito/timestepping/rungekutta.py or devito/timestepping/explicitmultistage.py that way additional timesteppers can be contributed as new files. (I'm thinking about implicit multistage, backward difference formulae etc...)

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I renamed the class to HighOrderRungeKuttaExponential. I realize the name might be confusing since this particular Runge-Kutta is explicit, but “EXP” was intended to highlight the exponential aspect. I’ve also updated the other class names based on your suggestions.

Regarding the file location, it’s currently in /types as recommended by @mloubout (see suggestion). Personally, I think both /timestepping and /types are reasonable options. Perhaps we can discuss this with @EdCaunt and @FabioLuporini to reach a consensus.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I agree this file doesn't belong to types/

based on https://github.com/devitocodes/devito/pull/2599/changes#r3043562368, we might add it to ir/dsl/rungekutta.py

Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
from .equation import Eq
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Make these imports absolute (devito.types.equation)

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

from .dense import Function
from devito.symbolics import uxreplace

# from devito.ir.support import SymbolRegistry

from .array import Array # Trying Array


method_registry = {}

def register_method(cls):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not a fan of using string matching here.

I'm also not sure why this function is needed, especially when the registry itself is just a regular dict

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea was to automatically map each time integration method's string representation to its corresponding class, eliminating the need to manually populate the dictionary. This allows users to specify the time integrator using only a string, without needing to explicitly import the corresponding class. Do you recommend another way to do this?

method_registry[cls.__name__] = cls
return cls


def resolve_method(method):
try:
return method_registry[method]
except KeyError:
raise ValueError(f"The time integrator '{method}' is not implemented.")


class MultiStage(Eq):
"""
Abstract base class for multi-stage time integration methods
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good docstring, but is it overindented by a level?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

(e.g., Runge-Kutta schemes) in Devito.

This class represents a symbolic equation of the form `target = rhs`
and provides a mechanism to associate it with a time integration
scheme. The specific integration behavior must be implemented by
subclasses via the `_evaluate` method.

Parameters
----------
lhs : expr-like
The left-hand side of the equation, typically a time-updated Function
(e.g., `u.forward`).
rhs : expr-like, optional
The right-hand side of the equation to integrate. Defaults to 0.
subdomain : SubDomain, optional
A subdomain over which the equation applies.
coefficients : dict, optional
Optional dictionary of symbolic coefficients for the integration.
implicit_dims : tuple, optional
Additional dimensions that should be treated implicitly in the equation.
**kwargs : dict
Additional keyword arguments, such as time integration method selection.

Notes
-----
Subclasses must override the `_evaluate()` method to return a sequence
of update expressions for each stage in the integration process.
"""

def __new__(cls, lhs, rhs=0, subdomain=None, coefficients=None, implicit_dims=None, **kwargs):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This __new__ doesn't seem to do anything. Can it be removed?

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I change it now, for the stages coupling

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, I changed a little bit now to consider stage coupling.

return super().__new__(cls, lhs, rhs=rhs, subdomain=subdomain, coefficients=coefficients, implicit_dims=implicit_dims, **kwargs)

def _evaluate(self, **kwargs):
raise NotImplementedError(
f"_evaluate() must be implemented in the subclass {self.__class__.__name__}")


class RK(MultiStage):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
class RK(MultiStage):
class RungeKutta(MultiStage):

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

"""
Base class for explicit Runge-Kutta (RK) time integration methods defined
via a Butcher tableau.

This class handles the general structure of RK schemes by using
the Butcher coefficients (`a`, `b`, `c`) to expand a single equation into
a series of intermediate stages followed by a final update. Subclasses
must define `a`, `b`, and `c` as class attributes.

Parameters
----------
a : list of list of float
The coefficient matrix representing stage dependencies.
b : list of float
The weights for the final combination step.
c : list of float
The time shifts for each intermediate stage (often the row sums of `a`).

Attributes
----------
a : list[list[float]]
Butcher tableau `a` coefficients (stage coupling).
b : list[float]
Butcher tableau `b` coefficients (weights for combining stages).
c : list[float]
Butcher tableau `c` coefficients (stage time positions).
s : int
Number of stages in the RK method, inferred from `b`.
"""

def __init__(self, a=None, b=None, c=None, **kwargs):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This can just be def __init__(self, a: list[list[float | np.number]], b: list[float | np.number], c: list[float | np.number], **kwargs) -> None:, avoiding the need for the validation function

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

oh, nice...

self.a, self.b, self.c = self._validate(a, b, c)

def _validate(self, a, b, c):
if a is None or b is None or c is None:
raise ValueError("RK subclass must define class attributes of the Butcher's array a, b, and c")
return a, b, c

@property
def s(self):
return len(self.b)

def _evaluate(self, **kwargs):
"""
Generate the stage-wise equations for a Runge-Kutta time integration method.

This method takes a single equation of the form `Eq(u.forward, rhs)` and
expands it into a sequence of intermediate stage evaluations and a final
update equation according to the Runge-Kutta coefficients `a`, `b`, and `c`.

Returns
-------
list of Eq
A list of SymPy Eq objects representing:
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick: they will be Devito Eq objects

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

- `s` stage equations of the form `k_i = rhs evaluated at intermediate state`
- 1 final update equation of the form `u.forward = u + dt * sum(b_i * k_i)`
"""

u = self.lhs.function
rhs = self.rhs
grid = u.grid
t = grid.time_dim
dt = t.spacing

# Create temporary Functions to hold each stage
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nitpick: these are Array now

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

right!

# k = [Array(name=f'{kwargs.get('sregistry').make_name(prefix='k')}', dimensions=grid.shape, grid=grid, dtype=u.dtype) for i in range(self.s)] # Trying Array
k = [Function(name=f'{kwargs.get('sregistry').make_name(prefix='k')}', grid=grid, space_order=u.space_order, dtype=u.dtype)
for i in range(self.s)]

stage_eqs = []

# Build each stage
for i in range(self.s):
u_temp = u + dt * sum(aij * kj for aij, kj in zip(self.a[i][:i], k[:i]))
t_shift = t + self.c[i] * dt

# Evaluate RHS at intermediate value
stage_rhs = uxreplace(rhs, {u: u_temp, t: t_shift})
stage_eqs.append(Eq(k[i], stage_rhs))

# Final update: u.forward = u + dt * sum(b_i * k_i)
u_next = u + dt * sum(bi * ki for bi, ki in zip(self.b, k))
stage_eqs.append(Eq(u.forward, u_next))

return stage_eqs


@register_method
class RK44(RK):
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think RK4 and indeed all RK methods should be instances of the RK Class

Then you no longer need all of the boilerplate code below, which is just setting up Butcher tableau

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Or the coefficients should be class attributes and set by the child class

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think I didn't understand well. Aren’t they already implemented as described in the second option of your comment?

"""
Classic 4th-order Runge-Kutta (RK4) time integration method.

This class implements the classic explicit Runge-Kutta method of order 4 (RK44).
It uses four intermediate stages and specific Butcher coefficients to achieve
high accuracy while remaining explicit.

Attributes
----------
a : list[list[float]]
Coefficients of the `a` matrix for intermediate stage coupling.
b : list[float]
Weights for final combination.
c : list[float]
Time positions of intermediate stages.
"""
a = [[0, 0, 0, 0],
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would set these as tuples in the __init__. Definitely should not be mutable if set on a class level.

I would personally instead have a

def __init__(self):
    a = (...
    b = (...
    c = (...
    super.__init__(a=a, b=b, c=c)

Copy link
Copy Markdown
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did something like that... could you check?

[1/2, 0, 0, 0],
[0, 1/2, 0, 0],
[0, 0, 1, 0]]
b = [1/6, 1/3, 1/3, 1/6]
c = [0, 1/2, 1/2, 1]

def __init__(self, *args, **kwargs):
super().__init__(a=self.a, b=self.b, c=self.c)

Loading