Add Optimization Cookbook #5117
Conversation
I feel we could simplify the intro by doing the following:

```python
model = nnx.Sequential(
    nnx.Linear(2, 8, rngs=rngs),
    nnx.relu,
    nnx.Linear(8, 8, rngs=rngs),
)
optimizer = nnx.Optimizer(
    model,
    tx=optax.adam(1e-3),
    wrt=nnx.Param,
)
...

@nnx.jit
def train_step(model, optimizer, ema, x, y):
    loss_fn = lambda m, x, y: jnp.sum((m(x) - y) ** 2)
    loss, grads = nnx.value_and_grad(loss_fn)(model, x, y)
    optimizer.update(model, grads)
    ema.update(model)
    return loss
```
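The `ema.update(model)` call above presumably applies an exponential moving average to each parameter. A minimal stdlib-only sketch of that update rule, using plain dicts of floats in place of model parameters (the `ema_update` helper and the `decay` value are illustrative assumptions, not the Flax API):

```python
def ema_update(ema_params, params, decay=0.999):
    """Exponential moving average: ema <- decay * ema + (1 - decay) * param.

    ema_params / params are plain dicts of floats here; a real Flax
    implementation would walk the model's parameter leaves instead.
    """
    return {
        name: decay * ema_params[name] + (1.0 - decay) * params[name]
        for name in params
    }

# One EMA step starting from a zero-initialized shadow copy.
ema = {"w": 0.0, "b": 0.0}
params = {"w": 1.0, "b": 2.0}
ema = ema_update(ema, params, decay=0.9)
# ema["w"] is now 0.9 * 0.0 + 0.1 * 1.0, i.e. ~0.1
```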
After fully reading the guide, I get the sense that having the JAX versions makes the explanations longer and slightly harder to follow (because you have to mentally filter for the version you are interested in), and the JAX version doesn't necessarily make the NNX version easier to understand.
Fair enough! I'll convert it to nnx-only. |
```python
class EmaParam(nnx.Variable):
    @classmethod
    def from_variable(cls, var):
```
Seems like this classmethod could just be a function in `as_ema_params`?
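For illustration, the refactor this comment suggests might look like the following stdlib-only sketch, where the classmethod is replaced by a plain helper function (`Variable`, `EmaParam`, and `as_ema_params` here are toy stand-ins, not the actual Flax types):

```python
class Variable:
    """Toy stand-in for nnx.Variable: just wraps a value."""
    def __init__(self, value):
        self.value = value

class EmaParam(Variable):
    """Toy stand-in for the EmaParam variable type, with no classmethod."""
    pass

def as_ema_params(variables):
    # Build EMA shadow copies via a local helper instead of
    # EmaParam.from_variable, keeping the variable class minimal.
    return [EmaParam(v.value) for v in variables]

params = [Variable(1.0), Variable(2.0)]
ema = as_ema_params(params)
```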
```python
# simulate parameter update
def double(param):
    param[...] *= 2.0

jax.tree.map(double, model, is_leaf=lambda x: isinstance(x, nnx.Variable))
```
We could use the new `nnx.map` here.
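As a rough sketch of what such a tree map does, here is a stdlib-only `tree_map` with an `is_leaf` predicate that doubles every leaf, mirroring the shape of the `jax.tree.map(double, model, is_leaf=...)` call above (this `tree_map` is a toy reimplementation for dicts only, not the JAX or NNX API):

```python
def tree_map(fn, tree, is_leaf):
    """Recurse through nested dicts, applying fn wherever is_leaf is true."""
    if is_leaf(tree):
        return fn(tree)
    return {k: tree_map(fn, v, is_leaf) for k, v in tree.items()}

# A nested dict standing in for a model's parameter tree.
model = {"layer": {"w": 3.0, "b": 0.5}}
doubled = tree_map(
    lambda x: x * 2.0, model, is_leaf=lambda x: isinstance(x, float)
)
# doubled == {"layer": {"w": 6.0, "b": 1.0}}
```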
What does this PR do?
This PR adds a guide that shows some common techniques for working with Flax models during optimization. These include:
This document emphasizes a style as close to pure JAX as possible: to that end, it shows how the Flax version of each technique requires only minor deviations from the often more intuitive pure-JAX version.
Warnings: