Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 49 additions & 0 deletions flowrl/module/bronet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from dataclasses import field

import flax.linen as nn
import jax.numpy as jnp

import flowrl.module.initialization as init
from flowrl.types import *


class BronetBlock(nn.Module):
    """A single BroNet residual block: two Dense+LayerNorm stages with a skip.

    Attributes:
        hidden_dim: width of both dense layers; must equal the incoming
            feature width so the residual addition is well-formed.
        activation: nonlinearity applied between the two dense stages.
    """
    hidden_dim: int
    activation: Callable[[jnp.ndarray], jnp.ndarray]

    @nn.compact
    def __call__(
        self,
        x: jnp.ndarray,
        training: bool = False,
    ) -> jnp.ndarray:
        # Dense -> LayerNorm -> activation -> Dense -> LayerNorm, then skip-add.
        # `training` is accepted for interface parity but is unused here.
        h = nn.Dense(self.hidden_dim, kernel_init=init.orthogonal_init())(x)
        h = nn.LayerNorm()(h)
        h = self.activation(h)
        h = nn.Dense(self.hidden_dim, kernel_init=init.orthogonal_init())(h)
        h = nn.LayerNorm()(h)
        return x + h

class BroNet(nn.Module):
    """BroNet MLP: an input projection followed by uniform residual blocks.

    https://github.com/naumix/BiggerRegularizedCategorical/blob/a98a6378cfa4e875c0cf31de06801ae8cb8b1ed5/jaxrl/networks.py

    Attributes:
        hidden_dims: widths of the residual blocks; all entries must be equal
            because each BronetBlock's skip connection keeps the width fixed.
        output_dim: if > 0, a final Dense projects to this size; otherwise the
            last block's features are returned unchanged.
        activation: nonlinearity used in the input projection and every block.
    """
    hidden_dims: Sequence[int] = field(default_factory=lambda: [])
    output_dim: int = 0
    activation: Callable = nn.relu

    def setup(self):
        # Validate configuration once: residual blocks cannot change width.
        assert len(self.hidden_dims) > 0, "hidden_dims must be non-empty"
        for i in range(len(self.hidden_dims)):
            assert self.hidden_dims[i] == self.hidden_dims[0], "All hidden_dims must be the same for BroNet"

    @nn.compact
    def __call__(self, x: jnp.ndarray, training: bool = False):
        # Input projection into the residual stream's width.
        x = nn.Dense(self.hidden_dims[0], kernel_init=init.orthogonal_init())(x)
        x = nn.LayerNorm()(x)
        x = self.activation(x)
        # Residual blocks; thread `training` through like Simba does so
        # training-dependent behavior in blocks is not silently disabled.
        for size in self.hidden_dims:
            x = BronetBlock(size, self.activation)(x, training)
        # Optional output head; deliberately no activation on the output.
        if self.output_dim > 0:
            x = nn.Dense(self.output_dim, kernel_init=init.orthogonal_init())(x)
        return x
22 changes: 22 additions & 0 deletions flowrl/module/critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,25 @@ def __call__(
)(x)
mean, std = x[..., :1], jax.nn.softplus(x[..., 1:])
return mean, std

class CategoricalCritic(nn.Module):
    """Critic head producing `output_nodes` logits from a backbone's features.

    Attributes:
        backbone: feature extractor applied to the (obs[, action]) input.
        output_nodes: number of logits emitted by the final Dense layer.
        kernel_init: factory for the final layer's kernel initializer.
        bias_init: factory for the final layer's bias initializer.
    """
    backbone: nn.Module
    output_nodes: int
    kernel_init: Initializer = init.default_kernel_init
    bias_init: Initializer = init.default_bias_init

    @nn.compact
    def __call__(
        self,
        obs: jnp.ndarray,
        action: Optional[jnp.ndarray] = None,
    ) -> jnp.ndarray:
        # State-only critics pass no action; state-action critics concatenate
        # obs and action along the feature axis.
        inputs = obs if action is None else jnp.concatenate([obs, action], axis=-1)
        features = self.backbone(inputs)
        head = nn.Dense(
            self.output_nodes,
            kernel_init=self.kernel_init(),
            bias_init=self.bias_init(),
        )
        return head(features)
2 changes: 1 addition & 1 deletion flowrl/module/initialization.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import flax.linen as nn
import jax
from flax.linen.initializers import lecun_normal, variance_scaling, zeros_init
from flax.linen.initializers import lecun_normal, variance_scaling, zeros_init, he_normal
from jax import numpy as jnp

from flowrl.types import *
Expand Down
60 changes: 60 additions & 0 deletions flowrl/module/simba.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from dataclasses import field

import flax.linen as nn
import jax.numpy as jnp

import flowrl.module.initialization as init
from flowrl.types import *


class SimbaBlock(nn.Module):
    """Pre-LayerNorm inverted-bottleneck residual block used by Simba.

    Attributes:
        hidden_dim: width of the residual stream (input and output).
        multiplier: expansion factor of the intermediate dense layer.
    """
    hidden_dim: int
    multiplier: int = 4

    @nn.compact
    def __call__(
        self,
        x: jnp.ndarray,
        training: bool = False,
    ) -> jnp.ndarray:
        # LayerNorm -> expand -> ReLU -> contract, then add the skip.
        # `training` is accepted for interface parity but is unused here.
        h = nn.LayerNorm()(x)
        h = nn.Dense(self.hidden_dim * self.multiplier, kernel_init=init.he_normal())(h)
        h = nn.relu(h)
        h = nn.Dense(self.hidden_dim, kernel_init=init.he_normal())(h)
        return x + h

class Simba(nn.Module):
    """Simba residual MLP.

    https://arxiv.org/abs/2410.09754

    An orthogonal input projection, a stack of equal-width SimbaBlock residual
    blocks, a final LayerNorm, and an optional linear output head.

    Attributes:
        hidden_dims: widths of the residual blocks; all entries must be equal.
        output_dim: if > 0, project to this size after the final LayerNorm.
        multiplier: expansion factor inside each SimbaBlock.
    """
    hidden_dims: Sequence[int] = field(default_factory=lambda: [])
    output_dim: int = 0
    multiplier: int = 4

    def setup(self):
        # Residual blocks keep the stream width constant, so every entry of
        # hidden_dims has to agree with the first one.
        assert len(self.hidden_dims) > 0, "hidden_dims must be non-empty"
        first = self.hidden_dims[0]
        for dim in self.hidden_dims:
            assert dim == first, "All hidden_dims must be the same for simba"

    @nn.compact
    def __call__(
        self,
        x: jnp.ndarray,
        training: bool = False,
    ) -> jnp.ndarray:
        # Project inputs into the residual stream's width.
        x = nn.Dense(self.hidden_dims[0], kernel_init=init.orthogonal_init(1))(x)
        # Residual blocks (width checked in setup); thread `training` through.
        for width in self.hidden_dims:
            x = SimbaBlock(width, self.multiplier)(x, training)
        # Post-norm before the head.
        x = nn.LayerNorm()(x)
        # Output projection, deliberately with no activation.
        if self.output_dim > 0:
            x = nn.Dense(self.output_dim, kernel_init=init.orthogonal_init(1))(x)
        return x