This Pull Request implements the KV Cache optimization for all WAN models (WAN 2.1 & 2.2, both Text-to-Video and Image-to-Video). The optimization pre-computes the Key and Value projections for the text and image embeddings before the denoising loop, since they remain constant throughout.
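As a rough illustration of the idea (hypothetical names, not the PR's exact API), the constant text-embedding projections can be lifted out of the denoising loop:

```python
import jax
import jax.numpy as jnp

def cross_attention(q_proj, latents, cached_kv):
    """One cross-attention step that consumes precomputed key/value tensors."""
    q = latents @ q_proj                          # query depends on the latents
    k, v = cached_kv                              # constant across all steps
    scores = jax.nn.softmax(q @ k.T / jnp.sqrt(q.shape[-1]), axis=-1)
    return scores @ v

def denoise(latents, text_emb, params, timesteps):
    # Project K/V once, before the loop: the text embedding never changes.
    cached_kv = (text_emb @ params["k_proj"], text_emb @ params["v_proj"])
    for _ in timesteps:
        update = cross_attention(params["q_proj"], latents, cached_kv)
        latents = latents - 0.1 * update          # stand-in for a scheduler step
    return latents
```

With `N` denoising steps, this replaces `2 * N` K/V projections per cross-attention with a single pair computed up front.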
Additionally, this PR introduces a hardware-aware Dynamic Image Alignment Padding optimization for next-generation TPUs.
## Key Changes
### 1. KV Cache Optimization
- **`FlaxWanAttention`**: Modified `attention_flax.py` to accept `cached_kv`. If present, the key/value projections are bypassed (see the sketch after this list). Added a robust `compute_kv` method to pre-project text (and image) states.
- **`WanModel` & `WanTransformerBlock`**: Updated `transformer_wan.py` to support KV cache propagation. Added `compute_kv_cache` to precompute block-level cached keys/values, and integrated `skip_embeddings` inside `WanTimeTextImageEmbedding` to bypass redundant embedding layers when using cached states.
- Updated `wan_pipeline.py` to accept and propagate KV caches.
- Updated `wan_pipeline_2_1.py`, `wan_pipeline_2_2.py`, `wan_pipeline_i2v_2p1.py`, and `wan_pipeline_i2v_2p2.py` to pre-compute the KV cache before starting the loop and reuse it at every step when `use_kv_cache=True`.
- Added `use_kv_cache: False` to all default `.yml` configuration files to ensure backward compatibility.
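A minimal sketch of the `cached_kv` routing paired with a `compute_kv` method (a toy module, not the actual `FlaxWanAttention` code):

```python
import jax
from flax import nnx

class CrossAttentionSketch(nnx.Module):
    """Toy stand-in for a cross-attention layer with optional cached K/V."""

    def __init__(self, dim: int, *, rngs: nnx.Rngs):
        self.to_q = nnx.Linear(dim, dim, rngs=rngs)
        self.to_k = nnx.Linear(dim, dim, rngs=rngs)
        self.to_v = nnx.Linear(dim, dim, rngs=rngs)

    def compute_kv(self, encoder_states):
        # Pre-project the constant text/image states once, outside the loop.
        return self.to_k(encoder_states), self.to_v(encoder_states)

    def __call__(self, hidden_states, encoder_states=None, cached_kv=None):
        q = self.to_q(hidden_states)
        if cached_kv is not None:
            k, v = cached_kv                  # cache hit: skip both projections
        else:
            k, v = self.compute_kv(encoder_states)
        scores = jax.nn.softmax(
            q @ k.swapaxes(-2, -1) / q.shape[-1] ** 0.5, axis=-1)
        return scores @ v
```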
### 2. Dynamic Image Alignment Padding (Trillium & Ironwood Optimization)

Previously, image alignment padding was hard-coded to `128`. While optimal for older MXU tile sizes (TPU v4, v5p/v5e), next-generation hardware like Trillium (v6e) and Ironwood (v7x) utilizes larger tiles. Using hardware detection (`get_tpu_type()`), both `attention_flax.py` and `embeddings_flax.py` (`NNXWanImageEmbedding`) now dynamically adjust image alignment padding:

- `256` on Trillium (`v6e`) and Ironwood (`v7x`), to perfectly match the larger hardware tiles.
- `128` on older generations (`v5p` and below).
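A sketch of how such hardware-aware padding can be selected. The detection mechanism and helper names here are assumptions for illustration, not the PR's `get_tpu_type` implementation:

```python
import jax
import jax.numpy as jnp

def tile_alignment() -> int:
    # Assumed detection: inspect the local device kind. Trillium (v6e) and
    # Ironwood (v7x) get 256-wide alignment; older generations keep 128.
    kind = jax.devices()[0].device_kind.lower()
    return 256 if ("v6e" in kind or "v7x" in kind) else 128

def pad_image_embeddings(image_emb: jnp.ndarray) -> jnp.ndarray:
    """Pad the sequence axis of a (batch, seq, dim) tensor up to the alignment."""
    align = tile_alignment()
    pad = (-image_emb.shape[1]) % align       # 0 if already aligned
    return jnp.pad(image_emb, ((0, 0), (0, pad), (0, 0)))
```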
## Detailed File Changes

### Models
- `attention_flax.py`:
  - Added `get_tpu_type` and `TpuType` for dynamic hardware-aware image alignment padding.
  - Added `cached_kv` routing inside `FlaxWanAttention.__call__`.
  - Added `compute_kv` support for both T2V and I2V cross-attentions.
- `transformer_wan.py`:
  - Added a `skip_embeddings` parameter inside `WanTimeTextImageEmbedding` to bypass redundant text/image projections.
  - Updated `WanTransformerBlock` and `WanModel` to handle `cached_kv`/`kv_cache` passing.
  - Added `WanModel.compute_kv_cache` to precompute block-level cached keys/values across scan and non-scan layers (see the sketch after this list).
- `embeddings_flax.py`:
  - Updated `NNXWanImageEmbedding` to dynamically align to `256` for `v6e`/`v7x` and `128` otherwise, avoiding shape mismatches during cross-attention.
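Conceptually, the block-level cache is one `(key, value)` pair per transformer block, computed once from the constant encoder states. A sketch (assumed structure, reusing the toy attention module above):

```python
def compute_kv_cache_sketch(blocks, encoder_states):
    # Each block owns its own K/V projection weights, so the cache is a
    # list of per-block (key, value) pairs, projected exactly once.
    return [block.compute_kv(encoder_states) for block in blocks]

def apply_blocks(blocks, hidden_states, kv_cache):
    # Inside the denoising loop: feed each block its own cached pair.
    for block, cached_kv in zip(blocks, kv_cache):
        hidden_states = block(hidden_states, cached_kv=cached_kv)
    return hidden_states
```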
### Pipelines

- `wan_pipeline.py`: Updated `transformer_forward_pass`, `transformer_forward_pass_full_cfg`, and `transformer_forward_pass_cfg_cache` to accept and pass `kv_cache`.
- `wan_pipeline_2_1.py` & `wan_pipeline_2_2.py`: Added a `use_kv_cache` parameter to pipeline calls, and pre-computed `kv_cache` and `rotary_emb` prior to the denoising loop (see the sketch after this list).
- `wan_pipeline_i2v_2p1.py` & `wan_pipeline_i2v_2p2.py`: Added `kv_cache` support for I2V workflows.
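On the pipeline side, the flow is roughly as follows (names assumed, mirroring the description above; the `rotary_emb` precomputation follows the same pattern of hoisting loop-invariant work):

```python
def sample_sketch(model, scheduler, latents, encoder_states, timesteps,
                  use_kv_cache=False):
    # Everything constant across steps is computed once, up front.
    kv_cache = model.compute_kv_cache(encoder_states) if use_kv_cache else None
    for t in timesteps:
        noise_pred = model(latents, t, encoder_states, kv_cache=kv_cache)
        latents = scheduler.step(noise_pred, t, latents)
    return latents
```

With `use_kv_cache=False` (the shipped default), `kv_cache` stays `None` and the per-step projections run exactly as before, which is what preserves backward compatibility.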
### Configs

- Added `use_kv_cache: False` to all default configs (`base_wan_1_3b.yml`, `base_wan_14b.yml`, `base_wan_27b.yml`, `base_wan_i2v_14b.yml`, `base_wan_i2v_27b.yml`).

## Performance Note