From 3e7eb8fbdcb7b561dccda80301c62e5f1ff10aed Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Wed, 21 Jan 2026 17:07:22 +0800 Subject: [PATCH 1/2] fix: ensure 't' keyword argument defaults to 0 in _call_integral and format code in build method --- brainpy/integrators/base.py | 4 ++++ brainpy/integrators/ode/explicit_rk.py | 6 +++--- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/brainpy/integrators/base.py b/brainpy/integrators/base.py index fbc15c6c..bdf444f7 100644 --- a/brainpy/integrators/base.py +++ b/brainpy/integrators/base.py @@ -141,6 +141,10 @@ def state_delays(self, value): raise ValueError('Cannot set "state_delays" by users.') def _call_integral(self, *args, **kwargs): + kwargs = dict(kwargs) + t = kwargs.get('t', None) + kwargs['t'] = 0. if t is None else t + if _during_compile: jaxpr, out_shapes = jax.make_jaxpr(self.integral, return_shape=True)(**kwargs) outs = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *jax.tree.leaves(kwargs)) diff --git a/brainpy/integrators/ode/explicit_rk.py b/brainpy/integrators/ode/explicit_rk.py index 59e76994..9b425a91 100644 --- a/brainpy/integrators/ode/explicit_rk.py +++ b/brainpy/integrators/ode/explicit_rk.py @@ -178,8 +178,7 @@ def __init__(self, def build(self): # step stage - common.step(self.variables, C.DT, - self.A, self.C, self.code_lines, self.parameters) + common.step(self.variables, C.DT, self.A, self.C, self.code_lines, self.parameters) # variable update return_args = common.update(self.variables, C.DT, self.B, self.code_lines) # returns @@ -189,7 +188,8 @@ def build(self): code_scope={k: v for k, v in self.code_scope.items()}, code_lines=self.code_lines, show_code=self.show_code, - func_name=self.func_name) + func_name=self.func_name + ) class Euler(ExplicitRKIntegrator): From 2995effa8aaa4812dde3821d43f74c28a7c75460 Mon Sep 17 00:00:00 2001 From: Chaoming Wang Date: Wed, 21 Jan 2026 17:11:24 +0800 Subject: [PATCH 2/2] fix: update backend import for compatibility with jax>=0.8.0 --- brainpy/math/environment.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/brainpy/math/environment.py b/brainpy/math/environment.py index e9cb6ac9..ca258ce2 100644 --- a/brainpy/math/environment.py +++ b/brainpy/math/environment.py @@ -25,7 +25,6 @@ import brainstate.environ import jax from jax import config, numpy as jnp, devices -from jax.lib import xla_bridge from . import modes from . import scales @@ -733,8 +732,13 @@ def clear_buffer_memory( Clear name cache. Default is True. """ + if jax.__version_info__ < (0, 8, 0): + from jax.lib.xla_bridge import get_backend + else: + from jax.extend.backend import get_backend + if array: - for buf in xla_bridge.get_backend(platform).live_buffers(): + for buf in get_backend(platform).live_buffers(): buf.delete() if compilation: jax.clear_caches()