From 82ded2494b626df7f1076d2d4013cd31a8e7b3e2 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 21 Feb 2026 06:43:55 +0000 Subject: [PATCH 1/2] Initial plan From 201059e5baa792bd9b8d335caebf2046c0f9d990 Mon Sep 17 00:00:00 2001 From: "copilot-swe-agent[bot]" <198982749+Copilot@users.noreply.github.com> Date: Sat, 21 Feb 2026 06:48:21 +0000 Subject: [PATCH 2/2] fix: correct Dense input dim in BroNet and use multiplier in SimbaBlock Co-authored-by: typoverflow <41679605+typoverflow@users.noreply.github.com> --- flowrl/module/bronet.py | 2 +- flowrl/module/simba.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/flowrl/module/bronet.py b/flowrl/module/bronet.py index b1f06dc..1a03151 100644 --- a/flowrl/module/bronet.py +++ b/flowrl/module/bronet.py @@ -39,7 +39,7 @@ def setup(self): @nn.compact def __call__(self, x: jnp.ndarray): - x = nn.Dense(self.hidden_dims, kernel_init=init.orthogonal_init())(x) + x = nn.Dense(self.hidden_dims[0], kernel_init=init.orthogonal_init())(x) x = nn.LayerNorm()(x) x = self.activation(x) for i, size in enumerate(self.hidden_dims): diff --git a/flowrl/module/simba.py b/flowrl/module/simba.py index 57c5de4..74282a4 100644 --- a/flowrl/module/simba.py +++ b/flowrl/module/simba.py @@ -19,7 +19,7 @@ def __call__( ) -> jnp.ndarray: res = x x = nn.LayerNorm()(x) - x = nn.Dense(self.hidden_dim * 4, kernel_init=init.he_normal())(x) + x = nn.Dense(self.hidden_dim * self.multiplier, kernel_init=init.he_normal())(x) x = nn.relu(x) x = nn.Dense(self.hidden_dim, kernel_init=init.he_normal())(x) return res + x