vfdev-5 reviewed Nov 7, 2025
    {'params': params}, batch, z_rng
)

@nnx.jit
Collaborator
Let's use donate_argnums to donate the model and optimizer and reduce GPU memory usage.
Contributor (Author)
I tried adding donate_argnums to nnx.jit in train_step, but I was getting NaN loss and KL divergence.
What should I do?
Co-authored-by: vfdev <vfdev.5@gmail.com>
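For concreteness, here is a minimal sketch of the pattern under discussion. It is an assumption-laden illustration, not the PR's code: the model, loss, and optimizer are placeholders, and the exact nnx.Optimizer API may differ across Flax versions.

```python
import jax.numpy as jnp
import optax
from flax import nnx


class TinyModel(nnx.Module):  # stand-in module, only for illustration
  def __init__(self, rngs: nnx.Rngs):
    self.linear = nnx.Linear(4, 4, rngs=rngs)

  def __call__(self, x):
    return self.linear(x)


# donate_argnums asks XLA to reuse the buffers of the donated arguments
# (here the model and the optimizer state) instead of allocating new ones.
@nnx.jit(donate_argnums=(0, 1))
def train_step(model, optimizer, batch):
  def loss_fn(model):
    pred = model(batch['x'])
    return jnp.mean((pred - batch['y']) ** 2)

  loss, grads = nnx.value_and_grad(loss_fn)(model)
  optimizer.update(grads)  # update signature varies across Flax versions
  return loss


model = TinyModel(rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, optax.adam(1e-3))
batch = {'x': jnp.ones((8, 4)), 'y': jnp.zeros((8, 4))}
loss = train_step(model, optimizer, batch)
```

As the reply above notes, this exact pattern produced NaN loss and KL divergence in the VAE train_step, so it is shown only to make the suggestion concrete, not as a verified fix.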
vfdev-5 reviewed Nov 11, 2025
logging.info('Total training time: %.2f seconds', time.perf_counter() - start)


if __name__ == '__main__':
  app.run(main)
Collaborator
@sanepunk why did you remove the abseil app and the usage of the config file?
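For context, the abseil + ml_collections pattern being referred to (used throughout the Flax examples) looks roughly like the sketch below; the train module and its train_and_evaluate(config, workdir) entry point are assumptions for illustration.

```python
from absl import app, flags
from ml_collections import config_flags

import train  # assumed module exposing train_and_evaluate(config, workdir)

FLAGS = flags.FLAGS

flags.DEFINE_string('workdir', None, 'Directory to store model data.')
config_flags.DEFINE_config_file(
    'config',
    None,
    'File path to the training hyperparameter configuration.',
    lock_config=True,
)


def main(argv):
  if len(argv) > 1:
    raise app.UsageError('Too many command-line arguments.')
  train.train_and_evaluate(FLAGS.config, FLAGS.workdir)


if __name__ == '__main__':
  flags.mark_flags_as_required(['config', 'workdir'])
  app.run(main)
```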
import ml_collections


def get_config():
Collaborator
Let's keep this file and create the training config using a dataclass, like in your examples/vae/config.py.
Finally, examples/vae/config.py can be removed. You can follow the same approach as here: https://github.com/google/flax/blob/main/examples/gemma/configs/default.py
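A hedged sketch of what such a dataclass-based config might look like; the field names and default values below are placeholders, not the example's actual hyperparameters.

```python
import dataclasses


@dataclasses.dataclass(frozen=True)
class TrainingConfig:
  """Training hyperparameters for the VAE example (placeholder values)."""

  learning_rate: float = 1e-3
  latents: int = 20
  batch_size: int = 128
  num_epochs: int = 30


def get_config() -> TrainingConfig:
  """Returns the default training configuration."""
  return TrainingConfig()
```

A frozen dataclass keeps the configuration immutable and self-documenting; whether it is still loaded through ml_collections config flags or passed to the trainer directly is a separate choice.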
from config import TrainingConfig, get_default_config
import os


def setup_training_args():
Migrate VAE Example to Flax NNX with JIT Optimization
Summary
This PR migrates the VAE (Variational Autoencoder) example from Flax Linen to Flax NNX, the new simplified API. The migration includes proper use of @nnx.jit decorators for significant performance improvements.
Changes
1. Model Architecture (models.py)
- Migrated from nn.Module (Linen) to nnx.Module (NNX)
- Replaced @nn.compact decorators with explicit __init__ methods
- Uses nnx.Linear, nnx.relu, and nnx.sigmoid (a short illustrative sketch follows the Changes list)
2. Training Logic (train.py)
- Removed the Linen import (from flax import linen as nn)
- Removed train_state.TrainState (not used in NNX)
- Uses nnx.Optimizer for direct state management
- Added the @nnx.jit decorator to the train_step() function
- Added the @nnx.jit decorator to the eval_f() function
- Uses nnx.value_and_grad for gradient computation
- Uses nnx.Rngs
- Uses jax.nn.log_sigmoid (NNX compatible)
3. Code Quality
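To make items 1 and 2 above concrete, here is a hedged sketch of an NNX-style module with an explicit __init__; the layer sizes and names are illustrative assumptions, not the actual models.py.

```python
import jax
import jax.numpy as jnp
from flax import nnx


class Encoder(nnx.Module):
  """NNX module: layers are created eagerly in __init__ (no @nn.compact)."""

  def __init__(self, din: int, dhidden: int, dlatent: int, *, rngs: nnx.Rngs):
    self.fc1 = nnx.Linear(din, dhidden, rngs=rngs)
    self.fc_mean = nnx.Linear(dhidden, dlatent, rngs=rngs)
    self.fc_logvar = nnx.Linear(dhidden, dlatent, rngs=rngs)

  def __call__(self, x: jax.Array):
    h = nnx.relu(self.fc1(x))
    return self.fc_mean(h), self.fc_logvar(h)


# Modules hold their own state and are called directly,
# with no external params dict threaded through .apply().
encoder = Encoder(din=784, dhidden=500, dlatent=20, rngs=nnx.Rngs(0))
mean, logvar = encoder(jnp.ones((1, 784)))
```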
Performance Benchmarks
Training Time Comparison (30 epochs on binarized MNIST)
Hardware: CPU
Detailed Training Logs
Linen (Original) - 770.55 seconds
NNX (This PR) - 83.62 seconds ⚡
Performance Analysis
Time Saved: 686.93 seconds (11.45 minutes) for 30 epochs
Key Performance Factors:
- JIT compilation (@nnx.jit): 5-10x speedup through XLA optimization
Model Quality
Both implementations converge to similar final loss values (~100.86-100.89), demonstrating that the NNX migration maintains training quality while dramatically improving performance.
Testing
Compatibility
- Requires flax >= 0.8.0 (NNX API support)
- Configuration lives in configs/default.py
- Run with: python main.py --workdir=/tmp/mnist --config=configs/default.py
Migration Benefits
- Modules are called directly instead of through .apply()
Related Documentation
Checklist