-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathsimba.py
More file actions
60 lines (52 loc) · 1.69 KB
/
simba.py
File metadata and controls
60 lines (52 loc) · 1.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from dataclasses import field
import flax.linen as nn
import jax.numpy as jnp
import flowrl.module.initialization as init
from flowrl.types import *
class SimbaBlock(nn.Module):
    """Residual MLP block with LayerNorm on the residual branch (used by Simba).

    Computes ``x + Dense_h(relu(Dense_{m*h}(LayerNorm(x))))`` where
    ``h = hidden_dim`` and ``m = multiplier``. Both Dense layers use
    ``init.he_normal()`` kernel initialization.

    The skip connection adds the block output (last axis ``hidden_dim``) to the
    unmodified input, so the input's last axis should equal ``hidden_dim``;
    otherwise the addition raises or (if one side has width 1) silently
    broadcasts.
    """

    # Width of the block's input and output (the residual stream).
    hidden_dim: int
    # Expansion factor: the inner Dense layer has hidden_dim * multiplier units.
    multiplier: int = 4

    @nn.compact
    def __call__(
        self,
        x: jnp.ndarray,
        training: bool = False,
    ) -> jnp.ndarray:
        """Apply the block to ``x``.

        Args:
            x: input features; last axis expected to be ``hidden_dim``.
            training: accepted for interface compatibility; not used here.

        Returns:
            ``x`` plus the MLP branch output, same shape as ``x`` when widths match.
        """
        inner_width = self.hidden_dim * self.multiplier
        skip = x
        # Submodules are constructed in the order LayerNorm, Dense(inner), Dense(out)
        # so Flax's auto-assigned names (LayerNorm_0, Dense_0, Dense_1) are stable.
        branch = nn.LayerNorm()(x)
        branch = nn.relu(nn.Dense(inner_width, kernel_init=init.he_normal())(branch))
        branch = nn.Dense(self.hidden_dim, kernel_init=init.he_normal())(branch)
        return skip + branch
class Simba(nn.Module):
    """SimBa network: input projection, residual blocks, final LayerNorm, optional head.

    Reference: https://arxiv.org/abs/2410.09754

    Forward pass as implemented here:
      1. ``Dense(hidden_dims[0])`` with ``init.orthogonal_init(1)`` projects the
         input into the residual stream.
      2. One :class:`SimbaBlock` per entry of ``hidden_dims``.
      3. A final ``LayerNorm``.
      4. If ``output_dim > 0``, a linear ``Dense(output_dim)`` head (no
         activation) with ``init.orthogonal_init(1)``; otherwise the normalized
         features are returned.

    Every block adds its output back onto its input, so all entries of
    ``hidden_dims`` must be equal; this is validated in :meth:`setup`.
    """

    # Width of each residual block (one block per entry); must be non-empty and
    # all entries equal.
    hidden_dims: Sequence[int] = field(default_factory=list)
    # Size of the final linear head; 0 (the default) means no head.
    output_dim: int = 0
    # Inner-layer expansion factor passed to each SimbaBlock.
    multiplier: int = 4

    def setup(self):
        """Validate ``hidden_dims``.

        Raises:
            ValueError: if ``hidden_dims`` is empty or its entries are not all
                equal (the residual connections require a constant width).
        """
        # Explicit raises rather than ``assert``: assertions are removed under
        # ``python -O``, and unequal widths such as [1, 256] would then silently
        # broadcast inside the residual add instead of failing loudly.
        if len(self.hidden_dims) == 0:
            raise ValueError("hidden_dims must be non-empty")
        width = self.hidden_dims[0]
        if any(dim != width for dim in self.hidden_dims):
            raise ValueError("All hidden_dims must be the same for simba")

    @nn.compact
    def __call__(
        self,
        x: jnp.ndarray,
        training: bool = False,
    ) -> jnp.ndarray:
        """Apply the network to ``x``.

        Args:
            x: input features; the last axis is projected to ``hidden_dims[0]``.
            training: forwarded to every :class:`SimbaBlock`.

        Returns:
            Array whose last axis is ``output_dim`` if ``output_dim > 0``,
            otherwise ``hidden_dims[0]``.
        """
        # Submodule construction order (Dense, SimbaBlock_*, LayerNorm, Dense)
        # determines Flax's auto-assigned parameter names; keep it stable.
        x = nn.Dense(self.hidden_dims[0], kernel_init=init.orthogonal_init(1))(x)
        for size in self.hidden_dims:
            x = SimbaBlock(size, self.multiplier)(x, training)
        x = nn.LayerNorm()(x)
        if self.output_dim > 0:
            # Linear output head: intentionally no activation.
            x = nn.Dense(self.output_dim, kernel_init=init.orthogonal_init(1))(x)
        return x