-
Notifications
You must be signed in to change notification settings - Fork 102
Fix ODE integrator default time and compatibility with jax>=0.9.0
#813
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…format code in build method
Reviewer's guide (collapsed on small PRs)Reviewer's GuideEnsures the ODE integrator’s internal _call_integral method always passes a default t=0 keyword argument when none is provided, and applies minor formatting cleanups in the explicit Runge-Kutta integrator implementation. Sequence diagram for _call_integral with default t keyword argumentsequenceDiagram
participant Caller
participant IntegratorBase
participant JAX as jax
participant JAXCore as jax_core
Caller->>IntegratorBase: _call_integral(args, kwargs)
IntegratorBase->>IntegratorBase: copy kwargs to new dict
IntegratorBase->>IntegratorBase: t = kwargs.get(t, None)
alt t is None
IntegratorBase->>IntegratorBase: kwargs[t] = 0.0
else t is provided
IntegratorBase->>IntegratorBase: kwargs[t] = t
end
alt _during_compile is True
IntegratorBase->>JAX: make_jaxpr(integral, return_shape=True)(**kwargs)
JAX-->>IntegratorBase: jaxpr, out_shapes
IntegratorBase->>JAXCore: eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *leaves(kwargs))
JAXCore-->>IntegratorBase: outs
else _during_compile is False
IntegratorBase->>IntegratorBase: integral(**kwargs)
IntegratorBase-->>Caller: result
end
Updated class diagram for IntegratorBase and ExplicitRKIntegratorclassDiagram
class IntegratorBase {
<<abstract>>
+integral
+_call_integral(*args, **kwargs)
}
class ExplicitRKIntegrator {
+variables
+A
+B
+C
+code_lines
+parameters
+code_scope
+show_code
+func_name
+build()
}
IntegratorBase <|-- ExplicitRKIntegrator
class CommonModule {
+step(variables, dt, A, C, code_lines, parameters)
+update(variables, dt, B, code_lines)
}
class CodeGenerator {
+compile(var_dict, arg_names, return_args, fun_name, code_scope, code_lines, show_code, func_name)
}
ExplicitRKIntegrator ..> CommonModule : uses
ExplicitRKIntegrator ..> CodeGenerator : uses
File-Level Changes
Tips and commandsInteracting with Sourcery
Customizing Your ExperienceAccess your dashboard to:
Getting Help
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hey - I've left some high level feedback:
- In
_call_integral, consider defaultingtat the call sites (or only when the target integral actually expects atkwarg) rather than unconditionally injectingtintokwargs, to avoid surprising failures for integrals that don’t accept atparameter.
Prompt for AI Agents
Please address the comments from this code review:
## Overall Comments
- In `_call_integral`, consider defaulting `t` at the call sites (or only when the target integral actually expects a `t` kwarg) rather than unconditionally injecting `t` into `kwargs`, to avoid surprising failures for integrals that don’t accept a `t` parameter.Help me be more useful! Please click 👍 or 👎 on each comment and I'll use the feedback to improve your reviews.
|
@sourcery-ai title |
jax>=0.9.0
Summary by Sourcery
Ensure the ODE integrator always passes a default time argument and perform minor formatting cleanups.
Bug Fixes:
Enhancements: