diff --git a/ngclearn/components/input_encoders/ganglionCell.py b/ngclearn/components/input_encoders/ganglionCell.py index 5f707280..b1b223ed 100644 --- a/ngclearn/components/input_encoders/ganglionCell.py +++ b/ngclearn/components/input_encoders/ganglionCell.py @@ -131,7 +131,6 @@ def __init__(self, name: str, self.n_cells = n_cells self.sigma = sigma - self.batch_size = batch_size self.area_shape = area_shape self.patch_shape = patch_shape self.step_shape = step_shape @@ -144,10 +143,10 @@ def __init__(self, name: str, filter = create_dog_filter(patch_shape=self.patch_shape, sigma=sigma) # ═════════════════ compartments initial values ════════════════════ - in_restVals = jnp.zeros((self.batch_size, + in_restVals = jnp.zeros((batch_size, *self.area_shape)) ## input: (B | ix | iy) - out_restVals = jnp.zeros((self.batch_size, ## output.shape: (B | n_cells * px * py) + out_restVals = jnp.zeros((batch_size, ## output.shape: (B | n_cells * px * py) self.n_cells * self.patch_shape[0] * self.patch_shape[1])) # ═══════════════════ set compartments ══════════════════════ @@ -176,11 +175,11 @@ def advance_state(self, t): self.outputs.set(outputs) @compilable - def reset(self): - in_restVals = jnp.zeros((self.batch_size, + def reset(self, batch_size): + in_restVals = jnp.zeros((batch_size, *self.area_shape)) ## input: (B | ix | iy) - out_restVals = jnp.zeros((self.batch_size, ## output.shape: (B | n_cells * px * py) + out_restVals = jnp.zeros((batch_size, ## output.shape: (B | n_cells * px * py) self.n_cells * self.patch_shape[0] * self.patch_shape[1])) self.inputs.set(in_restVals) diff --git a/ngclearn/components/neurons/graded/gaussianErrorCell.py b/ngclearn/components/neurons/graded/gaussianErrorCell.py index 776dad46..a3d0a1eb 100755 --- a/ngclearn/components/neurons/graded/gaussianErrorCell.py +++ b/ngclearn/components/neurons/graded/gaussianErrorCell.py @@ -48,7 +48,6 @@ def __init__(self, name, n_units, batch_size=1, sigma=1., shape=None, **kwargs): self.sigma_shape = sigma_shape self.shape = shape self.n_units = n_units - self.batch_size = batch_size ## Convolution shape setup self.width = self.height = n_units @@ -108,10 +107,10 @@ def advance_state(self, dt): ## compute Gaussian error cell output # @transition(output_compartments=["dmu", "dtarget", "dSigma", "target", "mu", "modulator", "L", "mask"]) # @staticmethod @compilable - def reset(self): ## reset core components/statistics - _shape = (self.batch_size, self.shape[0]) + def reset(self, batch_size): ## reset core components/statistics + _shape = (batch_size, self.shape[0]) if len(self.shape) > 1: - _shape = (self.batch_size, self.shape[0], self.shape[1], self.shape[2]) + _shape = (batch_size, self.shape[0], self.shape[1], self.shape[2]) restVals = jnp.zeros(_shape) dmu = restVals dtarget = restVals diff --git a/ngclearn/components/neurons/graded/rateCell.py b/ngclearn/components/neurons/graded/rateCell.py index 3cf50a22..c2628a1e 100755 --- a/ngclearn/components/neurons/graded/rateCell.py +++ b/ngclearn/components/neurons/graded/rateCell.py @@ -252,10 +252,10 @@ def advance_state(self, dt): self.zF.set(zF) @compilable - def reset(self): #, batch_size, shape): #n_units - _shape = (self.batch_size, self.shape[0]) + def reset(self, batch_size): + _shape = (batch_size, self.shape[0]) if len(self.shape) > 1: - _shape = (self.batch_size, self.shape[0], self.shape[1], self.shape[2]) + _shape = (batch_size, self.shape[0], self.shape[1], self.shape[2]) restVals = jnp.zeros(_shape) self.j.set(restVals) self.j_td.set(restVals) diff --git a/ngclearn/components/synapses/patched/hebbianPatchedSynapse.py b/ngclearn/components/synapses/patched/hebbianPatchedSynapse.py index a82ccf6a..a67a4c1d 100644 --- a/ngclearn/components/synapses/patched/hebbianPatchedSynapse.py +++ b/ngclearn/components/synapses/patched/hebbianPatchedSynapse.py @@ -281,9 +281,9 @@ def evolve(self): self.dBiases.set(dBiases) @compilable - def reset(self): - preVals = jnp.zeros((self.batch_size, self.shape[0])) - postVals = jnp.zeros((self.batch_size, self.shape[1])) + def reset(self, batch_size): + preVals = jnp.zeros((batch_size, self.shape[0])) + postVals = jnp.zeros((batch_size, self.shape[1])) # BUG: the self.inputs here does not have the targeted field # NOTE: Quick workaround is to check if targeted is in the input or not hasattr(self.inputs, "targeted") and not self.inputs.targeted and self.inputs.set(preVals) # inputs diff --git a/ngclearn/components/synapses/patched/patchedSynapse.py b/ngclearn/components/synapses/patched/patchedSynapse.py index 215cc278..1d98cf29 100644 --- a/ngclearn/components/synapses/patched/patchedSynapse.py +++ b/ngclearn/components/synapses/patched/patchedSynapse.py @@ -167,9 +167,9 @@ def advance_state(self): self.pre_out.set(pre_out) @compilable - def reset(self): - preVals = jnp.zeros((self.batch_size, self.shape[0])) - postVals = jnp.zeros((self.batch_size, self.shape[1])) + def reset(self, batch_size): + preVals = jnp.zeros((batch_size, self.shape[0])) + postVals = jnp.zeros((batch_size, self.shape[1])) # BUG: the self.inputs here does not have the targeted field # NOTE: Quick workaround is to check if targeted is in the input or not