Skip to content

Commit c049d44

Browse files
cmj2002Copilottypoverflow
authored
feat: add BroNet, SimBa and categorical critic (#21)
* feat: add BroNet, simba and categorical critic * fix: correct Dense input dim in BroNet and respect multiplier in SimbaBlock (#23) * Initial plan * fix: correct Dense input dim in BroNet and use multiplier in SimbaBlock Co-authored-by: typoverflow <41679605+typoverflow@users.noreply.github.com> --------- Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com> Co-authored-by: typoverflow <41679605+typoverflow@users.noreply.github.com> --------- Co-authored-by: Copilot <198982749+Copilot@users.noreply.github.com> Co-authored-by: typoverflow <41679605+typoverflow@users.noreply.github.com>
1 parent ebda0a7 commit c049d44

4 files changed

Lines changed: 132 additions & 1 deletion

File tree

flowrl/module/bronet.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
from dataclasses import field
2+
3+
import flax.linen as nn
4+
import jax.numpy as jnp
5+
6+
import flowrl.module.initialization as init
7+
from flowrl.types import *
8+
9+
10+
class BronetBlock(nn.Module):
    """Residual MLP block used by BroNet.

    Applies Dense -> LayerNorm -> activation -> Dense -> LayerNorm and adds
    the input back as a skip connection.

    Attributes:
        hidden_dim: width of both Dense layers; must match the feature size
            of the input for the skip connection to be valid.
        activation: nonlinearity applied between the two Dense layers.
    """
    hidden_dim: int
    activation: Callable[[jnp.ndarray], jnp.ndarray]

    @nn.compact
    def __call__(
        self,
        x: jnp.ndarray,
        training: bool = False,
    ) -> jnp.ndarray:
        # NOTE(review): `training` is accepted for interface parity with other
        # modules but is not used inside this block.
        out = nn.Dense(self.hidden_dim, kernel_init=init.orthogonal_init())(x)
        out = nn.LayerNorm()(out)
        out = self.activation(out)
        out = nn.Dense(self.hidden_dim, kernel_init=init.orthogonal_init())(out)
        out = nn.LayerNorm()(out)
        # Skip connection.
        return x + out
26+
27+
class BroNet(nn.Module):
    """BroNet: an input stem followed by a stack of residual LayerNorm MLP blocks.

    Reference implementation:
    https://github.com/naumix/BiggerRegularizedCategorical/blob/a98a6378cfa4e875c0cf31de06801ae8cb8b1ed5/jaxrl/networks.py

    Attributes:
        hidden_dims: widths of the residual blocks. All entries must be equal,
            because each block's skip connection requires a constant width.
        output_dim: if > 0, a final Dense projection to this size is applied
            (with no activation).
        activation: nonlinearity used in the stem and inside each block.
    """
    hidden_dims: Sequence[int] = field(default_factory=lambda: [])
    output_dim: int = 0
    activation: Callable = nn.relu

    def setup(self):
        # Validate the configuration once at module construction.
        assert len(self.hidden_dims) > 0, "hidden_dims must be non-empty"
        assert all(h == self.hidden_dims[0] for h in self.hidden_dims), \
            "All hidden_dims must be the same for BroNet"

    @nn.compact
    def __call__(self, x: jnp.ndarray):
        # Stem: project the input to the residual block width.
        x = nn.Dense(self.hidden_dims[0], kernel_init=init.orthogonal_init())(x)
        x = nn.LayerNorm()(x)
        x = self.activation(x)
        # Residual trunk.
        for size in self.hidden_dims:
            x = BronetBlock(size, self.activation)(x)
        # Optional output head (no activation).
        if self.output_dim > 0:
            x = nn.Dense(self.output_dim, kernel_init=init.orthogonal_init())(x)
        return x

flowrl/module/critic.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,3 +297,25 @@ def __call__(
297297
)(x)
298298
mean, std = x[..., :1], jax.nn.softplus(x[..., 1:])
299299
return mean, std
300+
301+
class CategoricalCritic(nn.Module):
    """Critic head that outputs `output_nodes` values (e.g. logits of a
    categorical value distribution) from observations and, optionally, actions.

    Attributes:
        backbone: feature extractor applied to the (concatenated) inputs.
        output_nodes: size of the final Dense projection.
        kernel_init: factory for the output layer's kernel initializer.
        bias_init: factory for the output layer's bias initializer.
    """
    backbone: nn.Module
    output_nodes: int
    kernel_init: Initializer = init.default_kernel_init
    bias_init: Initializer = init.default_bias_init

    @nn.compact
    def __call__(
        self,
        obs: jnp.ndarray,
        action: Optional[jnp.ndarray] = None,
    ) -> jnp.ndarray:
        # State-only critics omit the action; state-action critics concatenate
        # it along the feature axis.
        inputs = obs if action is None else jnp.concatenate([obs, action], axis=-1)
        features = self.backbone(inputs)
        head = nn.Dense(
            self.output_nodes, kernel_init=self.kernel_init(), bias_init=self.bias_init()
        )
        return head(features)

flowrl/module/initialization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import flax.linen as nn
22
import jax
3-
from flax.linen.initializers import lecun_normal, variance_scaling, zeros_init
3+
from flax.linen.initializers import lecun_normal, variance_scaling, zeros_init, he_normal
44
from jax import numpy as jnp
55

66
from flowrl.types import *

flowrl/module/simba.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
from dataclasses import field
2+
3+
import flax.linen as nn
4+
import jax.numpy as jnp
5+
6+
import flowrl.module.initialization as init
7+
from flowrl.types import *
8+
9+
10+
class SimbaBlock(nn.Module):
    """Pre-LayerNorm residual MLP block used by SimBa.

    Computes LayerNorm -> Dense(hidden_dim * multiplier) -> ReLU ->
    Dense(hidden_dim) and adds the result back to the input.

    Attributes:
        hidden_dim: input/output width; must match the input's feature size.
        multiplier: expansion factor of the inner (bottleneck) Dense layer.
    """
    hidden_dim: int
    multiplier: int = 4

    @nn.compact
    def __call__(
        self,
        x: jnp.ndarray,
        training: bool = False,
    ) -> jnp.ndarray:
        # NOTE(review): `training` is accepted for interface parity with other
        # modules but is not used inside this block.
        skip = x
        h = nn.LayerNorm()(x)
        # Expand by `multiplier`, then project back to hidden_dim.
        h = nn.Dense(self.hidden_dim * self.multiplier, kernel_init=init.he_normal())(h)
        h = nn.relu(h)
        h = nn.Dense(self.hidden_dim, kernel_init=init.he_normal())(h)
        return skip + h
26+
27+
class Simba(nn.Module):
    """SimBa encoder: input projection, pre-LN residual blocks, final LayerNorm.

    Reference: https://arxiv.org/abs/2410.09754

    Attributes:
        hidden_dims: widths of the residual blocks. All entries must be equal,
            because each block's skip connection requires a constant width.
        output_dim: if > 0, a final Dense projection to this size is applied
            (with no activation).
        multiplier: expansion factor inside each SimbaBlock.
    """
    hidden_dims: Sequence[int] = field(default_factory=lambda: [])
    output_dim: int = 0
    multiplier: int = 4

    def setup(self):
        # Validate the configuration once at module construction.
        assert len(self.hidden_dims) > 0, "hidden_dims must be non-empty"
        assert all(h == self.hidden_dims[0] for h in self.hidden_dims), \
            "All hidden_dims must be the same for simba"

    @nn.compact
    def __call__(
        self,
        x: jnp.ndarray,
        training: bool = False,
    ) -> jnp.ndarray:
        # First projection into the residual width.
        x = nn.Dense(self.hidden_dims[0], kernel_init=init.orthogonal_init(1))(x)
        # Residual blocks.
        for size in self.hidden_dims:
            x = SimbaBlock(size, self.multiplier)(x, training)
        # Final layer norm.
        x = nn.LayerNorm()(x)
        # Output projection (no activation here).
        if self.output_dim > 0:
            x = nn.Dense(self.output_dim, kernel_init=init.orthogonal_init(1))(x)
        return x

0 commit comments

Comments
 (0)