[Feature] Add LoRA Inference Support for WAN Models via Flax NNX #308
Summary
This PR introduces full Low-Rank Adaptation (LoRA) inference support for the WAN family of models in MaxDiffusion.
Unlike previous implementations in this codebase that rely on `flax.linen`, this implementation leverages `flax.nnx`. This allows for a more Pythonic, object-oriented approach to weight injection, enabling us to modify the transformer model in-place.

Key Features
1. Transition to `flax.nnx`

WAN models in MaxDiffusion are implemented using `flax.nnx`. To support LoRA, we implemented a native NNX loader rather than wrapping `linen` modules. The loader uses graph traversal (`nnx.iter_graph`) to identify target layers (`nnx.Linear`, `nnx.Conv`, `nnx.Embed`, `nnx.LayerNorm`) and merge LoRA weights directly into the kernel values.
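As a rough illustration of this approach (not the PR's actual code), the sketch below walks the live module with `nnx.iter_graph`, matches `nnx.Linear` layers, and adds the low-rank delta directly into each kernel. The `lora_weights` mapping and its `(A, B, alpha)` layout are hypothetical stand-ins for whatever the real loader produces.

```python
from flax import nnx


def merge_lora_into_linears(model: nnx.Module, lora_weights: dict, rank: int):
  """Sketch: in-place LoRA merge for nnx.Linear kernels found by graph traversal."""
  for path, node in nnx.iter_graph(model):
    if not isinstance(node, nnx.Linear):
      continue  # the real loader also targets Conv, Embed, and LayerNorm
    key = ".".join(str(p) for p in path)
    if key not in lora_weights:
      continue
    lora_a, lora_b, alpha = lora_weights[key]  # shapes: (in, r), (r, out), scalar
    delta = (lora_a @ lora_b) * (alpha / rank)
    # nnx modules are mutated in place; no separate params pytree is rebuilt.
    node.kernel.value = node.kernel.value + delta.astype(node.kernel.value.dtype)
```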
2. Robust Weight Merging Strategy

This implementation solves several critical distributed training/inference challenges:

- On-device merging (`jax.jit`): To avoid `ShardingMismatch` and `DeviceArray` errors that occur when mixing sharded TPU weights with CPU-based LoRA weights, all merge computations (kernel + delta) are performed within JIT-compiled functions (`_compute_and_add_*_jit`). This ensures weight updates occur efficiently on-device across the TPU mesh (see the sketch after this list).
- Zero-copy tensor transfer: Uses `jax.dlpack` where possible to efficiently move PyTorch tensors to JAX arrays without unnecessary memory overhead.
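The following is a minimal sketch of the two ideas above, using a hypothetical helper name (`compute_and_add`) rather than the PR's `_compute_and_add_*_jit` functions: PyTorch tensors are brought over via `jax.dlpack` when possible, and the kernel + delta addition runs inside `jax.jit` so it executes on-device with the kernel's existing sharding.

```python
from functools import partial

import jax
import jax.numpy as jnp
import torch


def torch_to_jax(t: torch.Tensor) -> jax.Array:
  """Move a PyTorch tensor to JAX, zero-copy via dlpack where possible."""
  try:
    return jax.dlpack.from_dlpack(t.contiguous())
  except Exception:
    return jnp.asarray(t.detach().cpu().numpy())  # fallback: host copy


@partial(jax.jit, donate_argnums=(0,))
def compute_and_add(kernel, lora_a, lora_b, scale):
  """On-device merge: kernel + (A @ B) * scale, preserving the kernel's dtype."""
  delta = (lora_a @ lora_b) * scale
  return kernel + delta.astype(kernel.dtype)
```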
3. Advanced LoRA Support

Beyond standard `Linear` rank reduction, this PR supports:

- Preparation of `diff` weights before device-side merging.
- Difference injections (`diff`, `diff_b`): Supports checkpoints that include full-parameter fine-tuning offsets (difference injections) and bias tuning, which are common in high-fidelity WAN fine-tunes (sketched after this list).
- Embedding and norm tuning: `text_embedding`, `time_embedding`, and `LayerNorm`/`RMSNorm` scales and biases.
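A minimal sketch of the `diff`/`diff_b` idea (illustrative names and shapes, not the PR's exact functions): because these offsets are full-parameter differences rather than low-rank factors, the merge is a plain elementwise addition onto the existing kernel and bias.

```python
import jax


@jax.jit
def apply_diff(kernel, bias, diff, diff_b):
  """Sketch: merge full-parameter offsets; diff matches kernel, diff_b matches bias."""
  new_kernel = kernel + diff.astype(kernel.dtype)
  new_bias = bias + diff_b.astype(bias.dtype)
  return new_kernel, new_bias
```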
4. Scanned vs. Unscanned Layers

MaxDiffusion supports enabling `jax.scan` for transformer layers via the `scan_layers: True` configuration flag. This improves training memory efficiency by stacking the weights of repeated layers (e.g., Attention, FFN) along a new leading dimension. Since users may run inference with or without this flag enabled, this LoRA implementation is designed to transparently support both modes.

The loader distinguishes between:

- Unscanned layers: The `merge_lora()` function is used, which iterates through each layer and merges weights individually via efficient, on-device JIT calls (`_compute_and_add_single_jit`).
- Scanned layers: The `merge_lora_for_scanned()` function is used. It detects which parameters are stacked (e.g., `kernel.ndim > 2`) and which are not.
  - Stacked parameters are merged in a single batched call to `_compute_and_add_scanned_jit`. This updates all layers in the stack at once on-device, which is significantly more efficient than merging layer-by-layer (see the sketch after this list).
  - Non-stacked parameters (`embeddings`, `proj_out`): these are merged individually using the single-layer JIT logic.

This dual approach ensures correct weight injection whether or not layers are scanned, while maximizing performance in scanned mode through batching.
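A condensed sketch of the scanned vs. unscanned distinction, using assumed shapes and a hypothetical `merge` dispatcher rather than the PR's `merge_lora`/`merge_lora_for_scanned` functions: stacked kernels carry a leading layer axis, so the whole stack can be merged in one batched einsum.

```python
import jax
import jax.numpy as jnp


@jax.jit
def merge_single(kernel, lora_a, lora_b, scale):
  # kernel: (in, out); lora_a: (in, r); lora_b: (r, out)
  return kernel + (lora_a @ lora_b).astype(kernel.dtype) * scale


@jax.jit
def merge_scanned(kernel, lora_a, lora_b, scale):
  # kernel: (L, in, out); per-layer LoRA factors stacked along the same leading axis.
  delta = jnp.einsum("lir,lro->lio", lora_a, lora_b) * scale
  return kernel + delta.astype(kernel.dtype)


def merge(kernel, lora_a, lora_b, scale):
  # Mirrors the `kernel.ndim > 2` check described above.
  if kernel.ndim > 2:
    return merge_scanned(kernel, lora_a, lora_b, scale)
  return merge_single(kernel, lora_a, lora_b, scale)
```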
Files Added / Modified
- `src/maxdiffusion/models/lora_nnx.py`: [NEW] Core logic. Contains the JIT merge functions, `parse_lora_dict`, and the graph traversal logic (`merge_lora`, `merge_lora_for_scanned`) to inject weights into NNX modules.
- `src/maxdiffusion/loaders/wan_lora_nnx_loader.py`: [NEW] Orchestrates the loading process. Handles the download of safetensors, conversion of keys, and delegation to the merge functions.
- `src/maxdiffusion/generate_wan.py`: Updated the generation pipeline to check whether `lora` is enabled and trigger the loading sequence before inference.
- `src/maxdiffusion/lora_conversion_utils.py`: Updated `translate_wan_nnx_path_to_diffusers_lora` to accurately map NNX paths (including embeddings and time projections) to Diffusers-style keys.
- `base_wan_lora_14b.yml` & `base_wan_lora_27b.yml`: Added a `lora_config` section to specify LoRA checkpoints and parameters during inference.

Testing