Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 5 additions & 6 deletions ngclearn/components/input_encoders/ganglionCell.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,6 @@ def __init__(self, name: str,
self.n_cells = n_cells
self.sigma = sigma

self.batch_size = batch_size
self.area_shape = area_shape
self.patch_shape = patch_shape
self.step_shape = step_shape
Expand All @@ -144,10 +143,10 @@ def __init__(self, name: str,
filter = create_dog_filter(patch_shape=self.patch_shape, sigma=sigma)

# ═════════════════ compartments initial values ════════════════════
in_restVals = jnp.zeros((self.batch_size,
in_restVals = jnp.zeros((batch_size,
*self.area_shape)) ## input: (B | ix | iy)

out_restVals = jnp.zeros((self.batch_size, ## output.shape: (B | n_cells * px * py)
out_restVals = jnp.zeros((batch_size, ## output.shape: (B | n_cells * px * py)
self.n_cells * self.patch_shape[0] * self.patch_shape[1]))

# ═══════════════════ set compartments ══════════════════════
Expand Down Expand Up @@ -176,11 +175,11 @@ def advance_state(self, t):
self.outputs.set(outputs)

@compilable
def reset(self):
in_restVals = jnp.zeros((self.batch_size,
def reset(self, batch_size):
in_restVals = jnp.zeros((batch_size,
*self.area_shape)) ## input: (B | ix | iy)

out_restVals = jnp.zeros((self.batch_size, ## output.shape: (B | n_cells * px * py)
out_restVals = jnp.zeros((batch_size, ## output.shape: (B | n_cells * px * py)
self.n_cells * self.patch_shape[0] * self.patch_shape[1]))

self.inputs.set(in_restVals)
Expand Down
7 changes: 3 additions & 4 deletions ngclearn/components/neurons/graded/gaussianErrorCell.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ def __init__(self, name, n_units, batch_size=1, sigma=1., shape=None, **kwargs):
self.sigma_shape = sigma_shape
self.shape = shape
self.n_units = n_units
self.batch_size = batch_size

## Convolution shape setup
self.width = self.height = n_units
Expand Down Expand Up @@ -108,10 +107,10 @@ def advance_state(self, dt): ## compute Gaussian error cell output
# @transition(output_compartments=["dmu", "dtarget", "dSigma", "target", "mu", "modulator", "L", "mask"])
# @staticmethod
@compilable
def reset(self): ## reset core components/statistics
_shape = (self.batch_size, self.shape[0])
def reset(self, batch_size): ## reset core components/statistics
_shape = (batch_size, self.shape[0])
if len(self.shape) > 1:
_shape = (self.batch_size, self.shape[0], self.shape[1], self.shape[2])
_shape = (batch_size, self.shape[0], self.shape[1], self.shape[2])
restVals = jnp.zeros(_shape)
dmu = restVals
dtarget = restVals
Expand Down
6 changes: 3 additions & 3 deletions ngclearn/components/neurons/graded/rateCell.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,10 +252,10 @@ def advance_state(self, dt):
self.zF.set(zF)

@compilable
def reset(self): #, batch_size, shape): #n_units
_shape = (self.batch_size, self.shape[0])
def reset(self, batch_size):
_shape = (batch_size, self.shape[0])
if len(self.shape) > 1:
_shape = (self.batch_size, self.shape[0], self.shape[1], self.shape[2])
_shape = (batch_size, self.shape[0], self.shape[1], self.shape[2])
restVals = jnp.zeros(_shape)
self.j.set(restVals)
self.j_td.set(restVals)
Expand Down
6 changes: 3 additions & 3 deletions ngclearn/components/synapses/patched/hebbianPatchedSynapse.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,9 +281,9 @@ def evolve(self):
self.dBiases.set(dBiases)

@compilable
def reset(self):
preVals = jnp.zeros((self.batch_size, self.shape[0]))
postVals = jnp.zeros((self.batch_size, self.shape[1]))
def reset(self, batch_size):
preVals = jnp.zeros((batch_size, self.shape[0]))
postVals = jnp.zeros((batch_size, self.shape[1]))
# BUG: the self.inputs here does not have the targeted field
# NOTE: Quick workaround is to check if targeted is in the input or not
hasattr(self.inputs, "targeted") and not self.inputs.targeted and self.inputs.set(preVals) # inputs
Expand Down
6 changes: 3 additions & 3 deletions ngclearn/components/synapses/patched/patchedSynapse.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,9 @@ def advance_state(self):
self.pre_out.set(pre_out)

@compilable
def reset(self):
preVals = jnp.zeros((self.batch_size, self.shape[0]))
postVals = jnp.zeros((self.batch_size, self.shape[1]))
def reset(self, batch_size):
preVals = jnp.zeros((batch_size, self.shape[0]))
postVals = jnp.zeros((batch_size, self.shape[1]))

# BUG: the self.inputs here does not have the targeted field
# NOTE: Quick workaround is to check if targeted is in the input or not
Expand Down