diff --git a/flowrl/module/bronet.py b/flowrl/module/bronet.py
new file mode 100644
index 0000000..1a03151
--- /dev/null
+++ b/flowrl/module/bronet.py
@@ -0,0 +1,49 @@
+from dataclasses import field
+
+import flax.linen as nn
+import jax.numpy as jnp
+
+import flowrl.module.initialization as init
+from flowrl.types import *
+
+
+class BronetBlock(nn.Module):
+    hidden_dim: int
+    activation: Callable[[jnp.ndarray], jnp.ndarray]
+
+    @nn.compact
+    def __call__(
+        self,
+        x: jnp.ndarray,
+        training: bool = False,
+    ) -> jnp.ndarray:
+        res = nn.Dense(self.hidden_dim, kernel_init=init.orthogonal_init())(x)
+        res = nn.LayerNorm()(res)
+        res = self.activation(res)
+        res = nn.Dense(self.hidden_dim, kernel_init=init.orthogonal_init())(res)
+        res = nn.LayerNorm()(res)
+        return res + x
+
+class BroNet(nn.Module):
+    """
+    https://github.com/naumix/BiggerRegularizedCategorical/blob/a98a6378cfa4e875c0cf31de06801ae8cb8b1ed5/jaxrl/networks.py
+    """
+    hidden_dims: Sequence[int] = field(default_factory=lambda: [])
+    output_dim: int = 0
+    activation: Callable = nn.relu
+
+    def setup(self):
+        assert len(self.hidden_dims) > 0, "hidden_dims must be non-empty"
+        for i in range(len(self.hidden_dims)):
+            assert self.hidden_dims[i] == self.hidden_dims[0], "All hidden_dims must be the same for BroNet"
+
+    @nn.compact
+    def __call__(self, x: jnp.ndarray):
+        x = nn.Dense(self.hidden_dims[0], kernel_init=init.orthogonal_init())(x)
+        x = nn.LayerNorm()(x)
+        x = self.activation(x)
+        for i, size in enumerate(self.hidden_dims):
+            x = BronetBlock(size, self.activation)(x)
+        if self.output_dim > 0:
+            x = nn.Dense(self.output_dim, kernel_init=init.orthogonal_init())(x)
+        return x
\ No newline at end of file
diff --git a/flowrl/module/critic.py b/flowrl/module/critic.py
index db573b6..488269f 100644
--- a/flowrl/module/critic.py
+++ b/flowrl/module/critic.py
@@ -212,3 +212,25 @@ def __call__(
         )(x)
         mean, std = x[..., :1], jax.nn.softplus(x[..., 1:])
         return mean, std
+
+class CategoricalCritic(nn.Module):
+    backbone: nn.Module
+    output_nodes: int
+    kernel_init: Initializer = init.default_kernel_init
+    bias_init: Initializer = init.default_bias_init
+
+    @nn.compact
+    def __call__(
+        self,
+        obs: jnp.ndarray,
+        action: Optional[jnp.ndarray] = None,
+    ) -> jnp.ndarray:
+        if action is None:
+            x = obs
+        else:
+            x = jnp.concatenate([obs, action], axis=-1)
+        x = self.backbone(x)
+        x = nn.Dense(
+            self.output_nodes, kernel_init=self.kernel_init(), bias_init=self.bias_init()
+        )(x)
+        return x
\ No newline at end of file
diff --git a/flowrl/module/initialization.py b/flowrl/module/initialization.py
index 0b51721..ddca408 100644
--- a/flowrl/module/initialization.py
+++ b/flowrl/module/initialization.py
@@ -1,6 +1,6 @@
 import flax.linen as nn
 import jax
-from flax.linen.initializers import lecun_normal, variance_scaling, zeros_init
+from flax.linen.initializers import lecun_normal, variance_scaling, zeros_init, he_normal
 from jax import numpy as jnp
 
 from flowrl.types import *
diff --git a/flowrl/module/simba.py b/flowrl/module/simba.py
new file mode 100644
index 0000000..74282a4
--- /dev/null
+++ b/flowrl/module/simba.py
@@ -0,0 +1,60 @@
+from dataclasses import field
+
+import flax.linen as nn
+import jax.numpy as jnp
+
+import flowrl.module.initialization as init
+from flowrl.types import *
+
+
+class SimbaBlock(nn.Module):
+    hidden_dim: int
+    multiplier: int = 4
+
+    @nn.compact
+    def __call__(
+        self,
+        x: jnp.ndarray,
+        training: bool = False,
+    ) -> jnp.ndarray:
+        res = x
+        x = nn.LayerNorm()(x)
+        x = nn.Dense(self.hidden_dim * self.multiplier, kernel_init=init.he_normal())(x)
+        x = nn.relu(x)
+        x = nn.Dense(self.hidden_dim, kernel_init=init.he_normal())(x)
+        return res + x
+
+class Simba(nn.Module):
+    """
+    https://arxiv.org/abs/2410.09754
+    """
+    hidden_dims: Sequence[int] = field(default_factory=lambda: [])
+    output_dim: int = 0
+    multiplier: int = 4
+
+    def setup(self):
+        assert len(self.hidden_dims) > 0, "hidden_dims must be non-empty"
+        for i in range(len(self.hidden_dims)):
+            assert self.hidden_dims[i] == self.hidden_dims[0], "All hidden_dims must be the same for simba"
+
+    @nn.compact
+    def __call__(
+        self,
+        x: jnp.ndarray,
+        training: bool = False,
+    ) -> jnp.ndarray:
+        # First projection
+        x = nn.Dense(self.hidden_dims[0], kernel_init=init.orthogonal_init(1))(x)
+        # Residual blocks
+        for i, size in enumerate(self.hidden_dims):
+            x = SimbaBlock(
+                size,
+                self.multiplier,
+            )(x, training)
+        # Final layer norm
+        x = nn.LayerNorm()(x)
+        # Output projection
+        if self.output_dim > 0:
+            # No activation here
+            x = nn.Dense(self.output_dim, kernel_init=init.orthogonal_init(1))(x)
+        return x
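# ---------------------------------------------------------------------------
# Reviewer note: a minimal usage sketch for the modules added in this diff,
# showing how the Simba and BroNet backbones compose with CategoricalCritic.
# Module names and import paths are taken from the diff above; the batch
# size, obs/action dimensions, and the 51-atom support are illustrative
# assumptions for the sketch, not values fixed by this PR.
# ---------------------------------------------------------------------------
import jax
import jax.numpy as jnp

from flowrl.module.bronet import BroNet
from flowrl.module.critic import CategoricalCritic
from flowrl.module.simba import Simba

rng = jax.random.PRNGKey(0)
obs = jnp.zeros((8, 17))    # assumed (batch, obs_dim)
action = jnp.zeros((8, 6))  # assumed (batch, act_dim)

# Simba encoder trunk; setup() asserts all hidden widths are equal, and
# output_dim=0 (the default) returns the post-LayerNorm features directly.
encoder = Simba(hidden_dims=[256, 256])
enc_params = encoder.init(rng, obs)
features = encoder.apply(enc_params, obs)  # (8, 256)

# Categorical critic head on a BroNet trunk: one logit per value atom
# (51 atoms is a C51-style choice used here purely for illustration).
critic = CategoricalCritic(
    backbone=BroNet(hidden_dims=[512, 512]),
    output_nodes=51,
)
critic_params = critic.init(rng, obs, action)
logits = critic.apply(critic_params, obs, action)  # (8, 51)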