From f6924f1e3f836f8bae6a36dc96821f641167b804 Mon Sep 17 00:00:00 2001
From: Cao Mingjun
Date: Sat, 21 Feb 2026 14:27:06 +0800
Subject: [PATCH 1/2] feat: add BroNet, simba and categorical critic

---
 flowrl/module/bronet.py         | 49 +++++++++++++++++++++++++++
 flowrl/module/critic.py         | 22 ++++++++++++
 flowrl/module/initialization.py |  2 +-
 flowrl/module/simba.py          | 60 +++++++++++++++++++++++++++++++++
 4 files changed, 132 insertions(+), 1 deletion(-)
 create mode 100644 flowrl/module/bronet.py
 create mode 100644 flowrl/module/simba.py

diff --git a/flowrl/module/bronet.py b/flowrl/module/bronet.py
new file mode 100644
index 0000000..b1f06dc
--- /dev/null
+++ b/flowrl/module/bronet.py
@@ -0,0 +1,49 @@
+from dataclasses import field
+
+import flax.linen as nn
+import jax.numpy as jnp
+
+import flowrl.module.initialization as init
+from flowrl.types import *
+
+
+class BronetBlock(nn.Module):
+    hidden_dim: int
+    activation: Callable[[jnp.ndarray], jnp.ndarray]
+
+    @nn.compact
+    def __call__(
+        self,
+        x: jnp.ndarray,
+        training: bool = False,
+    ) -> jnp.ndarray:
+        res = nn.Dense(self.hidden_dim, kernel_init=init.orthogonal_init())(x)
+        res = nn.LayerNorm()(res)
+        res = self.activation(res)
+        res = nn.Dense(self.hidden_dim, kernel_init=init.orthogonal_init())(res)
+        res = nn.LayerNorm()(res)
+        return res + x
+
+class BroNet(nn.Module):
+    """
+    https://github.com/naumix/BiggerRegularizedCategorical/blob/a98a6378cfa4e875c0cf31de06801ae8cb8b1ed5/jaxrl/networks.py
+    """
+    hidden_dims: Sequence[int] = field(default_factory=lambda: [])
+    output_dim: int = 0
+    activation: Callable = nn.relu
+
+    def setup(self):
+        assert len(self.hidden_dims) > 0, "hidden_dims must be non-empty"
+        for i in range(len(self.hidden_dims)):
+            assert self.hidden_dims[i] == self.hidden_dims[0], "All hidden_dims must be the same for BroNet"
+
+    @nn.compact
+    def __call__(self, x: jnp.ndarray):
+        x = nn.Dense(self.hidden_dims, kernel_init=init.orthogonal_init())(x)
+        x = nn.LayerNorm()(x)
+        x = self.activation(x)
+        for i, size in enumerate(self.hidden_dims):
+            x = BronetBlock(size, self.activation)(x)
+        if self.output_dim > 0:
+            x = nn.Dense(self.output_dim, kernel_init=init.orthogonal_init())(x)
+        return x
\ No newline at end of file
diff --git a/flowrl/module/critic.py b/flowrl/module/critic.py
index db573b6..488269f 100644
--- a/flowrl/module/critic.py
+++ b/flowrl/module/critic.py
@@ -212,3 +212,25 @@ def __call__(
         )(x)
         mean, std = x[..., :1], jax.nn.softplus(x[..., 1:])
         return mean, std
+
+class CategoricalCritic(nn.Module):
+    backbone: nn.Module
+    output_nodes: int
+    kernel_init: Initializer = init.default_kernel_init
+    bias_init: Initializer = init.default_bias_init
+
+    @nn.compact
+    def __call__(
+        self,
+        obs: jnp.ndarray,
+        action: Optional[jnp.ndarray] = None,
+    ) -> jnp.ndarray:
+        if action is None:
+            x = obs
+        else:
+            x = jnp.concatenate([obs, action], axis=-1)
+        x = self.backbone(x)
+        x = nn.Dense(
+            self.output_nodes, kernel_init=self.kernel_init(), bias_init=self.bias_init()
+        )(x)
+        return x
\ No newline at end of file
diff --git a/flowrl/module/initialization.py b/flowrl/module/initialization.py
index 0b51721..ddca408 100644
--- a/flowrl/module/initialization.py
+++ b/flowrl/module/initialization.py
@@ -1,6 +1,6 @@
 import flax.linen as nn
 import jax
-from flax.linen.initializers import lecun_normal, variance_scaling, zeros_init
+from flax.linen.initializers import lecun_normal, variance_scaling, zeros_init, he_normal
 from jax import numpy as jnp
 
 from flowrl.types import *
diff --git a/flowrl/module/simba.py b/flowrl/module/simba.py
new file mode 100644
index 0000000..57c5de4
--- /dev/null
+++ b/flowrl/module/simba.py
@@ -0,0 +1,60 @@
+from dataclasses import field
+
+import flax.linen as nn
+import jax.numpy as jnp
+
+import flowrl.module.initialization as init
+from flowrl.types import *
+
+
+class SimbaBlock(nn.Module):
+    hidden_dim: int
+    multiplier: int = 4
+
+    @nn.compact
+    def __call__(
+        self,
+        x: jnp.ndarray,
+        training: bool = False,
+    ) -> jnp.ndarray:
+        res = x
+        x = nn.LayerNorm()(x)
+        x = nn.Dense(self.hidden_dim * 4, kernel_init=init.he_normal())(x)
+        x = nn.relu(x)
+        x = nn.Dense(self.hidden_dim, kernel_init=init.he_normal())(x)
+        return res + x
+
+class Simba(nn.Module):
+    """
+    https://arxiv.org/abs/2410.09754
+    """
+    hidden_dims: Sequence[int] = field(default_factory=lambda: [])
+    output_dim: int = 0
+    multiplier: int = 4
+
+    def setup(self):
+        assert len(self.hidden_dims) > 0, "hidden_dims must be non-empty"
+        for i in range(len(self.hidden_dims)):
+            assert self.hidden_dims[i] == self.hidden_dims[0], "All hidden_dims must be the same for simba"
+
+    @nn.compact
+    def __call__(
+        self,
+        x: jnp.ndarray,
+        training: bool = False,
+    ) -> jnp.ndarray:
+        # First projection
+        x = nn.Dense(self.hidden_dims[0], kernel_init=init.orthogonal_init(1))(x)
+        # Residual blocks
+        for i, size in enumerate(self.hidden_dims):
+            x = SimbaBlock(
+                size,
+                self.multiplier,
+            )(x, training)
+        # Final layer norm
+        x = nn.LayerNorm()(x)
+        # Output projection
+        if self.output_dim > 0:
+            # No activation here
+            x = nn.Dense(self.output_dim, kernel_init=init.orthogonal_init(1))(x)
+        return x

From 33e53910e4527e3fed6e12668040aea1c9d57ec6 Mon Sep 17 00:00:00 2001
From: Copilot <198982749+Copilot@users.noreply.github.com>
Date: Sat, 21 Feb 2026 02:09:13 -0500
Subject: [PATCH 2/2] fix: correct Dense input dim in BroNet and respect multiplier in SimbaBlock (#23)

* Initial plan

* fix: correct Dense input dim in BroNet and use multiplier in SimbaBlock

Co-authored-by: typoverflow <41679605+typoverflow@users.noreply.github.com>

---------

Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: typoverflow <41679605+typoverflow@users.noreply.github.com>
---
 flowrl/module/bronet.py | 2 +-
 flowrl/module/simba.py  | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/flowrl/module/bronet.py b/flowrl/module/bronet.py
index b1f06dc..1a03151 100644
--- a/flowrl/module/bronet.py
+++ b/flowrl/module/bronet.py
@@ -39,7 +39,7 @@ def setup(self):
 
     @nn.compact
     def __call__(self, x: jnp.ndarray):
-        x = nn.Dense(self.hidden_dims, kernel_init=init.orthogonal_init())(x)
+        x = nn.Dense(self.hidden_dims[0], kernel_init=init.orthogonal_init())(x)
         x = nn.LayerNorm()(x)
         x = self.activation(x)
         for i, size in enumerate(self.hidden_dims):
diff --git a/flowrl/module/simba.py b/flowrl/module/simba.py
index 57c5de4..74282a4 100644
--- a/flowrl/module/simba.py
+++ b/flowrl/module/simba.py
@@ -19,7 +19,7 @@ def __call__(
     ) -> jnp.ndarray:
         res = x
         x = nn.LayerNorm()(x)
-        x = nn.Dense(self.hidden_dim * 4, kernel_init=init.he_normal())(x)
+        x = nn.Dense(self.hidden_dim * self.multiplier, kernel_init=init.he_normal())(x)
         x = nn.relu(x)
         x = nn.Dense(self.hidden_dim, kernel_init=init.he_normal())(x)
         return res + x