
Add Optimization Cookbook #5117

Merged

copybara-service[bot] merged 1 commit into google:main from samanklesaria:opt_cookbook on Apr 13, 2026

Conversation

@samanklesaria (Collaborator) commented Nov 28, 2025

What does this PR do?

This PR adds a guide that shows some common techniques for working with Flax models during optimization. These include:

  • Calculating exponential moving averages (EMA) of parameters
  • Optimizing only a low-rank addition to certain weights (LoRA)
  • Using different learning rates for different parameters to implement the maximal update parametrization (µP)
  • Using second-order optimizers like L-BFGS
  • Specifying sharding for optimizer state that differs from that of the parameter state
  • Gradient accumulation

The document emphasizes a style as close to pure JAX as possible: to that end, it shows how the Flax version of each technique requires only minor deviation from the often more intuitive pure-JAX version.
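
As a taste of the last item in the list, here is a minimal gradient-accumulation sketch. This is hedged: it borrows the `nnx.Optimizer` call suggested later in this thread, and `every_k_schedule=4` is an arbitrary illustrative choice, not the guide's actual code.

```python
import optax
from flax import nnx

model = nnx.Linear(2, 8, rngs=nnx.Rngs(0))

# optax.MultiSteps accumulates gradients and applies the wrapped
# transform (here adam) only once every 4 micro-batches.
tx = optax.MultiSteps(optax.adam(1e-3), every_k_schedule=4)
optimizer = nnx.Optimizer(model, tx=tx, wrt=nnx.Param)
```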


@review-notebook-app

Check out this pull request on ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@samanklesaria force-pushed the opt_cookbook branch 3 times, most recently from c495dc1 to b929529 (December 1, 2025)
@samanklesaria force-pushed the opt_cookbook branch 5 times, most recently from 34d7c20 to 444c6b6 (December 9, 2025)
@samanklesaria marked this pull request as ready for review (January 6, 2026)
@samanklesaria force-pushed the opt_cookbook branch 2 times, most recently from f894a0d to b37c527 (January 20, 2026)
@cgarciae (Collaborator) commented:

I feel we could simplify the intro by doing the following:

  1. Define a single model at the beginning and simply reuse it in all examples (it's just a guide).
  2. Inline the loss function.
```python
model = nnx.Sequential(
  nnx.Linear(2, 8, rngs=rngs),
  nnx.relu,
  nnx.Linear(8, 8, rngs=rngs),
)

optimizer = nnx.Optimizer(
  model,
  tx=optax.adam(1e-3),
  wrt=nnx.Param)

...

@nnx.jit
def train_step(model, optimizer, ema, x, y):
  loss_fn = lambda m, x, y: jnp.sum((m(x) - y) ** 2)
  loss, grads = nnx.value_and_grad(loss_fn)(model, x, y)
  optimizer.update(model, grads)
  ema.update(model)
  return loss
```
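
(For context, the `ema` object above isn't defined in the snippet. One hypothetical shape for it, reusing the Variable tree-mapping pattern that appears later in this thread, might be the following; the guide's actual helper may differ.)

```python
class Ema(nnx.Module):
  # Hypothetical sketch only; the merged guide's EMA helper (see the
  # EmaParam discussion below) may differ in detail.
  def __init__(self, model, decay=0.999):
    self.decay = decay
    self.model = nnx.clone(model)  # independent copy holding the EMA weights

  def update(self, model):
    # Blend each EMA weight toward the corresponding live parameter.
    def step(e, p):
      e[...] = self.decay * e[...] + (1.0 - self.decay) * p[...]
    jax.tree.map(
        step,
        nnx.state(self.model, nnx.Param),
        nnx.state(model, nnx.Param),
        is_leaf=lambda x: isinstance(x, nnx.Variable))
```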

Comment threads: docs_nnx/guides/opt_cookbook.rst (×5, outdated), flax/nnx/helpers.py
@cgarciae (Collaborator) commented:

After fully reading the guide, I'm getting the sense that having the JAX versions makes the explanations a bit longer and slightly harder to understand (because you have to mentally filter for the version you're interested in), and that the JAX version doesn't necessarily make the NNX version easier to understand.

@samanklesaria (Collaborator, Author) replied:

> After fully reading the guide, I'm getting the sense that having the JAX versions makes the explanations a bit longer and slightly harder to understand (because you have to mentally filter for the version you're interested in), and that the JAX version doesn't necessarily make the NNX version easier to understand.

Fair enough! I'll convert it to nnx-only.

@samanklesaria force-pushed the opt_cookbook branch 2 times, most recently from 76f8752 to f73edbd (February 3, 2026)
@samanklesaria force-pushed the opt_cookbook branch 2 times, most recently from 8b455a1 to 1370035 (March 4, 2026)
Comment thread on docs_nnx/guides/optimization_cookbook.md (outdated):

```python
class EmaParam(nnx.Variable):
  @classmethod
  def from_variable(cls, var):
    ...  # (excerpt truncated)
```
Collaborator

Seems like this classmethod could just be a function in as_ema_params?
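
For illustration, the suggested refactor might look roughly like this (hypothetical: `as_ema_params` and the tree traversal are inferred from the comment, not taken from the merged code):

```python
def as_ema_params(params):
  # Local replacement for the EmaParam.from_variable classmethod.
  def from_variable(var):
    return EmaParam(var[...])
  return jax.tree.map(from_variable, params,
                      is_leaf=lambda x: isinstance(x, nnx.Variable))
```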

```python
# simulate a parameter update
def double(param):
  param[...] *= 2.0

jax.tree.map(double, model, is_leaf=lambda x: isinstance(x, nnx.Variable))
```
Collaborator

We could use the new nnx.map here.

copybara-service[bot] merged commit 3b475c4 into google:main on Apr 13, 2026
21 checks passed