diff --git a/flowrl/module/bronet.py b/flowrl/module/bronet.py index b1f06dc..1a03151 100644 --- a/flowrl/module/bronet.py +++ b/flowrl/module/bronet.py @@ -39,7 +39,7 @@ def setup(self): @nn.compact def __call__(self, x: jnp.ndarray): - x = nn.Dense(self.hidden_dims, kernel_init=init.orthogonal_init())(x) + x = nn.Dense(self.hidden_dims[0], kernel_init=init.orthogonal_init())(x) x = nn.LayerNorm()(x) x = self.activation(x) for i, size in enumerate(self.hidden_dims): diff --git a/flowrl/module/simba.py b/flowrl/module/simba.py index 57c5de4..74282a4 100644 --- a/flowrl/module/simba.py +++ b/flowrl/module/simba.py @@ -19,7 +19,7 @@ def __call__( ) -> jnp.ndarray: res = x x = nn.LayerNorm()(x) - x = nn.Dense(self.hidden_dim * 4, kernel_init=init.he_normal())(x) + x = nn.Dense(self.hidden_dim * self.multiplier, kernel_init=init.he_normal())(x) x = nn.relu(x) x = nn.Dense(self.hidden_dim, kernel_init=init.he_normal())(x) return res + x