
feat: add KV caching support for Wan models#400

Open
Perseus14 wants to merge 1 commit into main from wan_kv_cache

Conversation


Perseus14 (Collaborator) commented May 6, 2026

This pull request implements KV cache optimization for all Wan models (Wan 2.1 & 2.2, both Text-to-Video and Image-to-Video). The optimization pre-computes the Key and Value projections of the text and image embeddings once, before the denoising loop, since those embeddings remain constant across all denoising steps.

Additionally, this PR introduces a hardware-aware Dynamic Image Alignment Padding optimization for next-generation TPUs.


Key Changes

1. KV Cache Optimization

  • Attention Level (FlaxWanAttention): Modified attention_flax.py to accept cached_kv; when present, the key/value projections are bypassed. Added a compute_kv method to pre-project the text (and image) states (a minimal sketch follows this list).
  • Transformer Level (WanModel & WanTransformerBlock): Updated transformer_wan.py to support KV cache propagation. Added compute_kv_cache to precompute block-level cached keys/values and integrated skip_embeddings inside WanTimeTextImageEmbedding to bypass redundant embedding layers when using cached states.
  • Pipeline Level (T2V & I2V Pipelines):
    • Updated forward pass helper signatures in wan_pipeline.py to accept and propagate KV caches.
    • Updated all denoising loops in wan_pipeline_2_1.py, wan_pipeline_2_2.py, wan_pipeline_i2v_2p1.py, and wan_pipeline_i2v_2p2.py to pre-compute the KV cache before starting the loop and reuse it at every step when use_kv_cache=True.
  • Config Defaults: Added use_kv_cache: False to all default .yml configuration files to ensure backward compatibility.

2. Dynamic Image Alignment Padding (Trillium & Ironwood Optimization)

  • Problem: Image embeddings were previously hardcoded to pad to multiples of 128. That is optimal for the $128 \times 128$ MXU tiles of older TPUs (v4, v5p/v5e), but next-generation hardware such as Trillium (v6e) and Ironwood (v7x) uses larger $256 \times 256$ MXU tile structures.
  • Solution: Replaced the hardcoded values with dynamic TPU hardware detection (get_tpu_type()). Both attention_flax.py and embeddings_flax.py (NNXWanImageEmbedding) now dynamically adjust the image alignment padding (a standalone sketch follows this list):
    • 256-alignment on Trillium (v6e) and Ironwood (v7x) to perfectly match larger hardware tiles.
    • 128-alignment fallback on older TPU architectures (v5p and below).

Detailed File Changes

Models

  • attention_flax.py:
    • Imported get_tpu_type and TpuType for dynamic hardware-aware image alignment padding.
    • Integrated cached_kv routing inside FlaxWanAttention.__call__.
    • Implemented compute_kv support for both T2V and I2V cross-attentions.
  • transformer_wan.py:
    • Added skip_embeddings parameter inside WanTimeTextImageEmbedding to bypass redundant text/image projections.
    • Updated WanTransformerBlock and WanModel to handle cached_kv / kv_cache passing.
    • Implemented WanModel.compute_kv_cache to precompute block-level cached keys/values across scan and non-scan layers (outlined after this list).
  • embeddings_flax.py:
    • Updated NNXWanImageEmbedding to dynamically align to 256 for v6e/v7x and 128 otherwise, avoiding shape mismatches during cross-attention.

Pipelines

  • wan_pipeline.py:
    • Updated transformer_forward_pass, transformer_forward_pass_full_cfg, and transformer_forward_pass_cfg_cache to accept and pass kv_cache.
  • wan_pipeline_2_1.py & wan_pipeline_2_2.py:
    • Added a use_kv_cache parameter to the pipeline calls and pre-computed kv_cache and rotary_emb prior to the denoising loop (usage sketched after this list).
  • wan_pipeline_i2v_2p1.py & wan_pipeline_i2v_2p2.py:
    • Fixed a RoPE dummy-shape bug; integrated pre-computed kv_cache support for the I2V workflows.

Configs

  • Configs (base_wan_1_3b.yml, base_wan_14b.yml, base_wan_27b.yml, base_wan_i2v_14b.yml, base_wan_i2v_27b.yml):
    • Added use_kv_cache: False to all default configs (snippet below).

Performance Note

  • Observed Latency Savings:
    • ~0.5s on TPU v7x-8 (Ironwood)
    • ~0.7s on TPU v6e-8 (Trillium)
  • Analysis: The latency savings during a full denoising run are minimal, which is expected: the cross-attention Key/Value projections operate on a very short text-prompt sequence (typically 512 tokens), so the FLOPs saved by caching them are a negligible fraction (less than 0.01%) of the total workload, which is dominated by the self-attention and FFN layers running over the much longer latent sequence at every denoising step. A back-of-envelope estimate follows.

