-
Notifications
You must be signed in to change notification settings - Fork 256
dsl: Introduce abstractions for multi-stage time integrators #2599
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 10 commits
1d830b8
7f087b3
214d882
d6c4d4a
78f8a0b
1c9d517
11db48b
83dfb04
d47a106
1f93a45
eea3a52
11d1429
4637ac2
ac1da7e
e9b3533
dc3dd77
1fd4a02
a0c45c1
93c6e3f
ef8d1ac
fa5acac
143d0c2
5f67b91
552fd7f
e9d2000
f7c9ea3
cf1003c
1fd480b
a875224
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -15,7 +15,7 @@ | |
| InvalidOperator) | ||
| from devito.logger import (debug, info, perf, warning, is_log_enabled_for, | ||
| switch_log_level) | ||
| from devito.ir.equations import LoweredEq, lower_exprs, concretize_subdims | ||
| from devito.ir.equations import LoweredEq, lower_multistage, lower_exprs, concretize_subdims | ||
| from devito.ir.clusters import ClusterGroup, clusterize | ||
| from devito.ir.iet import (Callable, CInterface, EntryFunction, FindSymbols, | ||
| MetaCall, derive_parameters, iet_build) | ||
|
|
@@ -36,7 +36,6 @@ | |
| disk_layer) | ||
| from devito.types.dimension import Thickness | ||
|
|
||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Please run the linter (
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||
| __all__ = ['Operator'] | ||
|
|
||
|
|
||
|
|
@@ -327,6 +326,8 @@ def _lower_exprs(cls, expressions, **kwargs): | |
| * Apply substitution rules; | ||
| * Shift indices for domain alignment. | ||
| """ | ||
| expressions = lower_multistage(expressions, **kwargs) | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This should at least be called after and, perhaps, benefit from a more generic name such as
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. we could also move it inside a
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. That makes sense, thanks! This was part of an earlier approach that after one meeting with Devito's team was decided to be left like a plan b, so it shouldn’t actually be here. I’ll remove it from the PR to appear only the actual approach—though I agree that structuring it that way would make sense if we revisit this idea in the future and I already changed accordingly.
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Quick correction to my previous comment: I realized this part is actually still in use in the current implementation. I’ve updated it taking your suggestions into account (ordering + naming), so it should now reflect what it was intended. |
||
|
|
||
| expand = kwargs['options'].get('expand', True) | ||
|
|
||
| # Specialization is performed on unevaluated expressions | ||
|
|
||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this file should be moved to somewhere like
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I renamed the class to Regarding the file location, it’s currently in
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I agree this file doesn't belong to based on https://github.com/devitocodes/devito/pull/2599/changes#r3043562368, we might add it to |
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,179 @@ | ||||||
| from .equation import Eq | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Make these imports absolute (
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||||||
| from .dense import Function | ||||||
| from devito.symbolics import uxreplace | ||||||
|
|
||||||
| # from devito.ir.support import SymbolRegistry | ||||||
|
|
||||||
| from .array import Array # Trying Array | ||||||
|
|
||||||
|
|
||||||
| method_registry = {} | ||||||
|
|
||||||
| def register_method(cls): | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm not a fan of using string matching here. I'm also not sure why this function is needed, especially when the registry itself is just a regular dict
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The idea was to automatically map each time integration method's string representation to its corresponding class, eliminating the need to manually populate the dictionary. This allows users to specify the time integrator using only a string, without needing to explicitly import the corresponding class. Do you recommend another way to do this? |
||||||
| method_registry[cls.__name__] = cls | ||||||
| return cls | ||||||
|
|
||||||
|
|
||||||
| def resolve_method(method): | ||||||
| try: | ||||||
| return method_registry[method] | ||||||
| except KeyError: | ||||||
| raise ValueError(f"The time integrator '{method}' is not implemented.") | ||||||
|
|
||||||
|
|
||||||
| class MultiStage(Eq): | ||||||
| """ | ||||||
| Abstract base class for multi-stage time integration methods | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Good docstring, but is it overindented by a level?
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||||||
| (e.g., Runge-Kutta schemes) in Devito. | ||||||
|
|
||||||
| This class represents a symbolic equation of the form `target = rhs` | ||||||
| and provides a mechanism to associate it with a time integration | ||||||
| scheme. The specific integration behavior must be implemented by | ||||||
| subclasses via the `_evaluate` method. | ||||||
|
|
||||||
| Parameters | ||||||
| ---------- | ||||||
| lhs : expr-like | ||||||
| The left-hand side of the equation, typically a time-updated Function | ||||||
| (e.g., `u.forward`). | ||||||
| rhs : expr-like, optional | ||||||
| The right-hand side of the equation to integrate. Defaults to 0. | ||||||
| subdomain : SubDomain, optional | ||||||
| A subdomain over which the equation applies. | ||||||
| coefficients : dict, optional | ||||||
| Optional dictionary of symbolic coefficients for the integration. | ||||||
| implicit_dims : tuple, optional | ||||||
| Additional dimensions that should be treated implicitly in the equation. | ||||||
| **kwargs : dict | ||||||
| Additional keyword arguments, such as time integration method selection. | ||||||
|
|
||||||
| Notes | ||||||
| ----- | ||||||
| Subclasses must override the `_evaluate()` method to return a sequence | ||||||
| of update expressions for each stage in the integration process. | ||||||
| """ | ||||||
|
|
||||||
| def __new__(cls, lhs, rhs=0, subdomain=None, coefficients=None, implicit_dims=None, **kwargs): | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, I change it now, for the stages coupling
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes, I changed a little bit now to consider stage coupling. |
||||||
| return super().__new__(cls, lhs, rhs=rhs, subdomain=subdomain, coefficients=coefficients, implicit_dims=implicit_dims, **kwargs) | ||||||
|
|
||||||
| def _evaluate(self, **kwargs): | ||||||
| raise NotImplementedError( | ||||||
| f"_evaluate() must be implemented in the subclass {self.__class__.__name__}") | ||||||
|
|
||||||
|
|
||||||
| class RK(MultiStage): | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||||||
| """ | ||||||
| Base class for explicit Runge-Kutta (RK) time integration methods defined | ||||||
| via a Butcher tableau. | ||||||
|
|
||||||
| This class handles the general structure of RK schemes by using | ||||||
| the Butcher coefficients (`a`, `b`, `c`) to expand a single equation into | ||||||
| a series of intermediate stages followed by a final update. Subclasses | ||||||
| must define `a`, `b`, and `c` as class attributes. | ||||||
|
|
||||||
| Parameters | ||||||
| ---------- | ||||||
| a : list of list of float | ||||||
| The coefficient matrix representing stage dependencies. | ||||||
| b : list of float | ||||||
| The weights for the final combination step. | ||||||
| c : list of float | ||||||
| The time shifts for each intermediate stage (often the row sums of `a`). | ||||||
|
|
||||||
| Attributes | ||||||
| ---------- | ||||||
| a : list[list[float]] | ||||||
| Butcher tableau `a` coefficients (stage coupling). | ||||||
| b : list[float] | ||||||
| Butcher tableau `b` coefficients (weights for combining stages). | ||||||
| c : list[float] | ||||||
| Butcher tableau `c` coefficients (stage time positions). | ||||||
| s : int | ||||||
| Number of stages in the RK method, inferred from `b`. | ||||||
| """ | ||||||
|
|
||||||
| def __init__(self, a=None, b=None, c=None, **kwargs): | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This can just be
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. oh, nice... |
||||||
| self.a, self.b, self.c = self._validate(a, b, c) | ||||||
|
|
||||||
| def _validate(self, a, b, c): | ||||||
| if a is None or b is None or c is None: | ||||||
| raise ValueError("RK subclass must define class attributes of the Butcher's array a, b, and c") | ||||||
| return a, b, c | ||||||
|
|
||||||
| @property | ||||||
| def s(self): | ||||||
| return len(self.b) | ||||||
|
|
||||||
| def _evaluate(self, **kwargs): | ||||||
| """ | ||||||
| Generate the stage-wise equations for a Runge-Kutta time integration method. | ||||||
|
|
||||||
| This method takes a single equation of the form `Eq(u.forward, rhs)` and | ||||||
| expands it into a sequence of intermediate stage evaluations and a final | ||||||
| update equation according to the Runge-Kutta coefficients `a`, `b`, and `c`. | ||||||
|
|
||||||
| Returns | ||||||
| ------- | ||||||
| list of Eq | ||||||
| A list of SymPy Eq objects representing: | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nitpick: they will be Devito
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. done |
||||||
| - `s` stage equations of the form `k_i = rhs evaluated at intermediate state` | ||||||
| - 1 final update equation of the form `u.forward = u + dt * sum(b_i * k_i)` | ||||||
| """ | ||||||
|
|
||||||
| u = self.lhs.function | ||||||
| rhs = self.rhs | ||||||
| grid = u.grid | ||||||
| t = grid.time_dim | ||||||
| dt = t.spacing | ||||||
|
|
||||||
| # Create temporary Functions to hold each stage | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Nitpick: these are
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. right! |
||||||
| # k = [Array(name=f'{kwargs.get('sregistry').make_name(prefix='k')}', dimensions=grid.shape, grid=grid, dtype=u.dtype) for i in range(self.s)] # Trying Array | ||||||
| k = [Function(name=f'{kwargs.get('sregistry').make_name(prefix='k')}', grid=grid, space_order=u.space_order, dtype=u.dtype) | ||||||
| for i in range(self.s)] | ||||||
|
|
||||||
| stage_eqs = [] | ||||||
|
|
||||||
| # Build each stage | ||||||
| for i in range(self.s): | ||||||
| u_temp = u + dt * sum(aij * kj for aij, kj in zip(self.a[i][:i], k[:i])) | ||||||
| t_shift = t + self.c[i] * dt | ||||||
|
|
||||||
| # Evaluate RHS at intermediate value | ||||||
| stage_rhs = uxreplace(rhs, {u: u_temp, t: t_shift}) | ||||||
| stage_eqs.append(Eq(k[i], stage_rhs)) | ||||||
|
|
||||||
| # Final update: u.forward = u + dt * sum(b_i * k_i) | ||||||
| u_next = u + dt * sum(bi * ki for bi, ki in zip(self.b, k)) | ||||||
| stage_eqs.append(Eq(u.forward, u_next)) | ||||||
|
|
||||||
| return stage_eqs | ||||||
|
|
||||||
|
|
||||||
| @register_method | ||||||
| class RK44(RK): | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think Then you no longer need all of the boilerplate code below, which is just setting up Butcher tableau
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Or the coefficients should be class attributes and set by the child class
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think I didn't understand well. Aren’t they already implemented as described in the second option of your comment? |
||||||
| """ | ||||||
| Classic 4th-order Runge-Kutta (RK4) time integration method. | ||||||
|
|
||||||
| This class implements the classic explicit Runge-Kutta method of order 4 (RK44). | ||||||
| It uses four intermediate stages and specific Butcher coefficients to achieve | ||||||
| high accuracy while remaining explicit. | ||||||
|
|
||||||
| Attributes | ||||||
| ---------- | ||||||
| a : list[list[float]] | ||||||
| Coefficients of the `a` matrix for intermediate stage coupling. | ||||||
| b : list[float] | ||||||
| Weights for final combination. | ||||||
| c : list[float] | ||||||
| Time positions of intermediate stages. | ||||||
| """ | ||||||
| a = [[0, 0, 0, 0], | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I would set these as tuples in the I would personally instead have a def __init__(self):
a = (...
b = (...
c = (...
super.__init__(a=a, b=b, c=c)
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I did something like that... could you check? |
||||||
| [1/2, 0, 0, 0], | ||||||
| [0, 1/2, 0, 0], | ||||||
| [0, 0, 1, 0]] | ||||||
| b = [1/6, 1/3, 1/3, 1/6] | ||||||
| c = [0, 1/2, 1/2, 1] | ||||||
|
|
||||||
| def __init__(self, *args, **kwargs): | ||||||
| super().__init__(a=self.a, b=self.b, c=self.c) | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is method a string here, or is it the class for the method? In the latter case, it would remove the need to have the
method_registrymapper. Furthermore, it would allow you to havemethod.resolve(target, sols_temp)here, which is tidierThere was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's a string. The idea is that the user provides a string to identify which time integrator to apply.