From 252ffe27e217393ab578d6044e180899b31d484c Mon Sep 17 00:00:00 2001 From: Faezeh Habibi <155960330+Faezehabibi@users.noreply.github.com> Date: Mon, 23 Mar 2026 16:45:01 -0400 Subject: [PATCH 1/5] Modify reset method to accept batch_size parameter for flexible test set size --- .../components/synapses/patched/hebbianPatchedSynapse.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ngclearn/components/synapses/patched/hebbianPatchedSynapse.py b/ngclearn/components/synapses/patched/hebbianPatchedSynapse.py index a82ccf6a..a67a4c1d 100644 --- a/ngclearn/components/synapses/patched/hebbianPatchedSynapse.py +++ b/ngclearn/components/synapses/patched/hebbianPatchedSynapse.py @@ -281,9 +281,9 @@ def evolve(self): self.dBiases.set(dBiases) @compilable - def reset(self): - preVals = jnp.zeros((self.batch_size, self.shape[0])) - postVals = jnp.zeros((self.batch_size, self.shape[1])) + def reset(self, batch_size): + preVals = jnp.zeros((batch_size, self.shape[0])) + postVals = jnp.zeros((batch_size, self.shape[1])) # BUG: the self.inputs here does not have the targeted field # NOTE: Quick workaround is to check if targeted is in the input or not hasattr(self.inputs, "targeted") and not self.inputs.targeted and self.inputs.set(preVals) # inputs From 4f3065ba2fcfb1dd4e7dc0a9862774b2d8e72b2d Mon Sep 17 00:00:00 2001 From: Faezeh Habibi <155960330+Faezehabibi@users.noreply.github.com> Date: Mon, 23 Mar 2026 16:46:37 -0400 Subject: [PATCH 2/5] Modify reset method to accept batch_size parameter --- ngclearn/components/synapses/patched/patchedSynapse.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ngclearn/components/synapses/patched/patchedSynapse.py b/ngclearn/components/synapses/patched/patchedSynapse.py index 215cc278..1d98cf29 100644 --- a/ngclearn/components/synapses/patched/patchedSynapse.py +++ b/ngclearn/components/synapses/patched/patchedSynapse.py @@ -167,9 +167,9 @@ def advance_state(self): self.pre_out.set(pre_out) @compilable - def reset(self): - preVals = jnp.zeros((self.batch_size, self.shape[0])) - postVals = jnp.zeros((self.batch_size, self.shape[1])) + def reset(self, batch_size): + preVals = jnp.zeros((batch_size, self.shape[0])) + postVals = jnp.zeros((batch_size, self.shape[1])) # BUG: the self.inputs here does not have the targeted field # NOTE: Quick workaround is to check if targeted is in the input or not From 03d375f17c5a19fcbd666e716ba726942ad8d904 Mon Sep 17 00:00:00 2001 From: Faezeh Habibi <155960330+Faezehabibi@users.noreply.github.com> Date: Mon, 23 Mar 2026 16:47:53 -0400 Subject: [PATCH 3/5] Refactor RateCell class reset function for flexible batch size --- ngclearn/components/neurons/graded/rateCell.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/ngclearn/components/neurons/graded/rateCell.py b/ngclearn/components/neurons/graded/rateCell.py index 3cf50a22..c2628a1e 100755 --- a/ngclearn/components/neurons/graded/rateCell.py +++ b/ngclearn/components/neurons/graded/rateCell.py @@ -252,10 +252,10 @@ def advance_state(self, dt): self.zF.set(zF) @compilable - def reset(self): #, batch_size, shape): #n_units - _shape = (self.batch_size, self.shape[0]) + def reset(self, batch_size): + _shape = (batch_size, self.shape[0]) if len(self.shape) > 1: - _shape = (self.batch_size, self.shape[0], self.shape[1], self.shape[2]) + _shape = (batch_size, self.shape[0], self.shape[1], self.shape[2]) restVals = jnp.zeros(_shape) self.j.set(restVals) self.j_td.set(restVals) From 17107497ea87feab631fc1f39ab6f1f4de50d911 Mon Sep 17 00:00:00 2001 From: Faezeh Habibi <155960330+Faezehabibi@users.noreply.github.com> Date: Mon, 23 Mar 2026 16:48:48 -0400 Subject: [PATCH 4/5] Refactor GaussianErrorCell class functions for flexible batch size --- ngclearn/components/neurons/graded/gaussianErrorCell.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/ngclearn/components/neurons/graded/gaussianErrorCell.py b/ngclearn/components/neurons/graded/gaussianErrorCell.py index 776dad46..a3d0a1eb 100755 --- a/ngclearn/components/neurons/graded/gaussianErrorCell.py +++ b/ngclearn/components/neurons/graded/gaussianErrorCell.py @@ -48,7 +48,6 @@ def __init__(self, name, n_units, batch_size=1, sigma=1., shape=None, **kwargs): self.sigma_shape = sigma_shape self.shape = shape self.n_units = n_units - self.batch_size = batch_size ## Convolution shape setup self.width = self.height = n_units @@ -108,10 +107,10 @@ def advance_state(self, dt): ## compute Gaussian error cell output # @transition(output_compartments=["dmu", "dtarget", "dSigma", "target", "mu", "modulator", "L", "mask"]) # @staticmethod @compilable - def reset(self): ## reset core components/statistics - _shape = (self.batch_size, self.shape[0]) + def reset(self, batch_size): ## reset core components/statistics + _shape = (batch_size, self.shape[0]) if len(self.shape) > 1: - _shape = (self.batch_size, self.shape[0], self.shape[1], self.shape[2]) + _shape = (batch_size, self.shape[0], self.shape[1], self.shape[2]) restVals = jnp.zeros(_shape) dmu = restVals dtarget = restVals From 565e643f5b3660e5a3c9d15844ec0ed2f2d55e34 Mon Sep 17 00:00:00 2001 From: Faezeh Habibi <155960330+Faezehabibi@users.noreply.github.com> Date: Mon, 23 Mar 2026 16:49:55 -0400 Subject: [PATCH 5/5] flexible batch_size --- ngclearn/components/input_encoders/ganglionCell.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/ngclearn/components/input_encoders/ganglionCell.py b/ngclearn/components/input_encoders/ganglionCell.py index 5f707280..b1b223ed 100644 --- a/ngclearn/components/input_encoders/ganglionCell.py +++ b/ngclearn/components/input_encoders/ganglionCell.py @@ -131,7 +131,6 @@ def __init__(self, name: str, self.n_cells = n_cells self.sigma = sigma - self.batch_size = batch_size self.area_shape = area_shape self.patch_shape = patch_shape self.step_shape = step_shape @@ -144,10 +143,10 @@ def __init__(self, name: str, filter = create_dog_filter(patch_shape=self.patch_shape, sigma=sigma) # ═════════════════ compartments initial values ════════════════════ - in_restVals = jnp.zeros((self.batch_size, + in_restVals = jnp.zeros((batch_size, *self.area_shape)) ## input: (B | ix | iy) - out_restVals = jnp.zeros((self.batch_size, ## output.shape: (B | n_cells * px * py) + out_restVals = jnp.zeros((batch_size, ## output.shape: (B | n_cells * px * py) self.n_cells * self.patch_shape[0] * self.patch_shape[1])) # ═══════════════════ set compartments ══════════════════════ @@ -176,11 +175,11 @@ def advance_state(self, t): self.outputs.set(outputs) @compilable - def reset(self): - in_restVals = jnp.zeros((self.batch_size, + def reset(self, batch_size): + in_restVals = jnp.zeros((batch_size, *self.area_shape)) ## input: (B | ix | iy) - out_restVals = jnp.zeros((self.batch_size, ## output.shape: (B | n_cells * px * py) + out_restVals = jnp.zeros((batch_size, ## output.shape: (B | n_cells * px * py) self.n_cells * self.patch_shape[0] * self.patch_shape[1])) self.inputs.set(in_restVals)