Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions brainpy/integrators/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,10 @@ def state_delays(self, value):
raise ValueError('Cannot set "state_delays" by users.')

def _call_integral(self, *args, **kwargs):
kwargs = dict(kwargs)
t = kwargs.get('t', None)
kwargs['t'] = 0. if t is None else t

if _during_compile:
jaxpr, out_shapes = jax.make_jaxpr(self.integral, return_shape=True)(**kwargs)
outs = jax.core.eval_jaxpr(jaxpr.jaxpr, jaxpr.consts, *jax.tree.leaves(kwargs))
Expand Down
6 changes: 3 additions & 3 deletions brainpy/integrators/ode/explicit_rk.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,7 @@ def __init__(self,

def build(self):
# step stage
common.step(self.variables, C.DT,
self.A, self.C, self.code_lines, self.parameters)
common.step(self.variables, C.DT, self.A, self.C, self.code_lines, self.parameters)
# variable update
return_args = common.update(self.variables, C.DT, self.B, self.code_lines)
# returns
Expand All @@ -189,7 +188,8 @@ def build(self):
code_scope={k: v for k, v in self.code_scope.items()},
code_lines=self.code_lines,
show_code=self.show_code,
func_name=self.func_name)
func_name=self.func_name
)


class Euler(ExplicitRKIntegrator):
Expand Down
8 changes: 6 additions & 2 deletions brainpy/math/environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
import brainstate.environ
import jax
from jax import config, numpy as jnp, devices
from jax.lib import xla_bridge

from . import modes
from . import scales
Expand Down Expand Up @@ -733,8 +732,13 @@ def clear_buffer_memory(
Clear name cache. Default is True.

"""
if jax.__version_info__ < (0, 8, 0):
from jax.lib.xla_bridge import get_backend
else:
from jax.extend.backend import get_backend

if array:
for buf in xla_bridge.get_backend(platform).live_buffers():
for buf in get_backend(platform).live_buffers():
buf.delete()
if compilation:
jax.clear_caches()
Expand Down
Loading