From 11429c0c9cfdfabe580c5e18d753bc992652df19 Mon Sep 17 00:00:00 2001 From: Rishabh Manoj Date: Tue, 5 May 2026 19:07:15 +0000 Subject: [PATCH] feat: add KV caching support for Wan models --- src/maxdiffusion/configs/base_wan_14b.yml | 1 + src/maxdiffusion/configs/base_wan_1_3b.yml | 1 + src/maxdiffusion/configs/base_wan_27b.yml | 2 +- src/maxdiffusion/configs/base_wan_i2v_14b.yml | 2 +- src/maxdiffusion/configs/base_wan_i2v_27b.yml | 2 +- src/maxdiffusion/models/attention_flax.py | 233 +++++++++++++++--- src/maxdiffusion/models/embeddings_flax.py | 7 +- .../wan/transformers/transformer_wan.py | 206 ++++++++++++++-- .../pipelines/wan/wan_pipeline.py | 94 ++++++- .../pipelines/wan/wan_pipeline_2_1.py | 65 ++++- .../pipelines/wan/wan_pipeline_2_2.py | 127 ++++++++-- .../pipelines/wan/wan_pipeline_i2v_2p1.py | 59 ++++- .../pipelines/wan/wan_pipeline_i2v_2p2.py | 177 +++++++++++-- 13 files changed, 842 insertions(+), 134 deletions(-) diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index c2c83c9f7..f432928aa 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -355,6 +355,7 @@ use_cfg_cache: False # Batch positive and negative prompts in text encoder to save compute. use_batched_text_encoder: False +use_kv_cache: False use_magcache: False magcache_thresh: 0.12 magcache_K: 2 diff --git a/src/maxdiffusion/configs/base_wan_1_3b.yml b/src/maxdiffusion/configs/base_wan_1_3b.yml index 1fd384eb1..0e0552656 100644 --- a/src/maxdiffusion/configs/base_wan_1_3b.yml +++ b/src/maxdiffusion/configs/base_wan_1_3b.yml @@ -301,6 +301,7 @@ flow_shift: 3.0 # Diffusion CFG cache (FasterCache-style, WAN 2.1 T2V only) use_cfg_cache: False +use_kv_cache: False # Batch positive and negative prompts in text encoder to save compute. use_batched_text_encoder: False diff --git a/src/maxdiffusion/configs/base_wan_27b.yml b/src/maxdiffusion/configs/base_wan_27b.yml index 1ce67a3cf..bf29fa867 100644 --- a/src/maxdiffusion/configs/base_wan_27b.yml +++ b/src/maxdiffusion/configs/base_wan_27b.yml @@ -331,7 +331,7 @@ use_cfg_cache: False # Batch positive and negative prompts in text encoder to save compute. use_batched_text_encoder: False - +use_kv_cache: False # SenCache: Sensitivity-Aware Caching (arXiv:2602.24208) — skip forward pass # when predicted output change (based on accumulated latent/timestep drift) is small use_sen_cache: False diff --git a/src/maxdiffusion/configs/base_wan_i2v_14b.yml b/src/maxdiffusion/configs/base_wan_i2v_14b.yml index 214cf5ce4..2af520727 100644 --- a/src/maxdiffusion/configs/base_wan_i2v_14b.yml +++ b/src/maxdiffusion/configs/base_wan_i2v_14b.yml @@ -318,7 +318,7 @@ use_cfg_cache: False # Batch positive and negative prompts in text encoder to save compute. use_batched_text_encoder: False - +use_kv_cache: False # SenCache: Sensitivity-Aware Caching (arXiv:2602.24208) use_sen_cache: False use_magcache: False diff --git a/src/maxdiffusion/configs/base_wan_i2v_27b.yml b/src/maxdiffusion/configs/base_wan_i2v_27b.yml index d2eb451d4..f5536a1cf 100644 --- a/src/maxdiffusion/configs/base_wan_i2v_27b.yml +++ b/src/maxdiffusion/configs/base_wan_i2v_27b.yml @@ -330,7 +330,7 @@ use_cfg_cache: False # Batch positive and negative prompts in text encoder to save compute. 
use_batched_text_encoder: False - +use_kv_cache: False # SenCache: Sensitivity-Aware Caching (arXiv:2602.24208) use_sen_cache: False diff --git a/src/maxdiffusion/models/attention_flax.py b/src/maxdiffusion/models/attention_flax.py index ae938b541..155309250 100644 --- a/src/maxdiffusion/models/attention_flax.py +++ b/src/maxdiffusion/models/attention_flax.py @@ -15,7 +15,7 @@ import contextlib import functools import math -from typing import Optional, Callable, Tuple +from typing import Optional, Callable, Tuple, Dict import flax.linen as nn from flax import nnx import jax @@ -30,6 +30,8 @@ from maxdiffusion.kernels.splash_attention import base as tokamax_splash_base from einops import rearrange from .. import common_types, max_logging +from maxdiffusion.tpu_utils import get_tpu_type, TpuType + from ..kernels import custom_splash_attention as custom_splash from . import quantizations @@ -677,7 +679,13 @@ def wrap_ulysses_attention(query, key, value): # Restore the original layout expected by the rest of the model: # head-sharded / full-sequence -> sequence-sharded / full-heads. - attention_output = jax.lax.all_to_all(attention_output, axis_name=axis_name, split_axis=2, concat_axis=1, tiled=True) + attention_output = jax.lax.all_to_all( + attention_output, + axis_name=axis_name, + split_axis=2, + concat_axis=1, + tiled=True, + ) return attention_output devices_in_batch_sharding = mesh.shape["data"] * (mesh.shape["fsdp"] if "fsdp" in mesh.shape else 1) @@ -739,7 +747,11 @@ def _apply_attention_dot( query_chunk_size = int(flatten_latent_dim) hidden_states = jax_memory_efficient_attention( - query_states, key_states, value_states, query_chunk_size=query_chunk_size, key_chunk_size=4096 * 4 + query_states, + key_states, + value_states, + query_chunk_size=query_chunk_size, + key_chunk_size=4096 * 4, ) hidden_states = hidden_states.transpose(1, 0, 2) @@ -1040,7 +1052,12 @@ def chunk_scanner(chunk_idx): def jax_memory_efficient_attention( - query, key, value, precision=jax.lax.Precision.HIGHEST, query_chunk_size: int = 1024, key_chunk_size: int = 4096 + query, + key, + value, + precision=jax.lax.Precision.HIGHEST, + query_chunk_size: int = 1024, + key_chunk_size: int = 4096, ): r""" Flax Memory-efficient multi-head dot product attention. 
https://arxiv.org/abs/2112.05682v2 @@ -1072,11 +1089,20 @@ def chunk_scanner(chunk_idx, _): return ( chunk_idx + query_chunk_size, # unused ignore it - _query_chunk_attention(query=query_chunk, key=key, value=value, precision=precision, key_chunk_size=key_chunk_size), + _query_chunk_attention( + query=query_chunk, + key=key, + value=value, + precision=precision, + key_chunk_size=key_chunk_size, + ), ) _, res = jax.lax.scan( - f=chunk_scanner, init=0, xs=None, length=math.ceil(num_q / query_chunk_size) # start counter # stop counter + f=chunk_scanner, + init=0, + xs=None, + length=math.ceil(num_q / query_chunk_size), # start counter # stop counter ) return jnp.concatenate(res, axis=-3) # fuse the chunked result back @@ -1547,6 +1573,7 @@ def __call__( encoder_attention_mask: Optional[jax.Array] = None, deterministic: bool = True, rngs: nnx.Rngs = None, + cached_kv: Optional[Dict[str, Tuple[jax.Array, jax.Array]]] = None, ) -> jax.Array: axis_names = nn.logical_to_mesh_axes((BATCH, LENGTH, HEAD)) hidden_states = jax.lax.with_sharding_constraint(hidden_states, axis_names) @@ -1566,16 +1593,22 @@ def __call__( if not is_i2v_cross_attention: with jax.named_scope("query_proj"): query_proj = self.query(hidden_states) - with jax.named_scope("key_proj"): - key_proj = self.key(encoder_hidden_states) - with jax.named_scope("value_proj"): - value_proj = self.value(encoder_hidden_states) if self.qk_norm: with self.conditional_named_scope("attn_q_norm"): query_proj = self.norm_q(query_proj) - with self.conditional_named_scope("attn_k_norm"): - key_proj = self.norm_k(key_proj) + + if not is_self_attention and cached_kv is not None and "text" in cached_kv: + key_proj, value_proj = cached_kv["text"] + else: + with jax.named_scope("key_proj"): + key_proj = self.key(encoder_hidden_states) + with jax.named_scope("value_proj"): + value_proj = self.value(encoder_hidden_states) + + if self.qk_norm: + with self.conditional_named_scope("attn_k_norm"): + key_proj = self.norm_k(key_proj) if rotary_emb is not None: with self.conditional_named_scope("attn_rope"): @@ -1591,7 +1624,10 @@ def __call__( with jax.named_scope("apply_attention"): attn_output = self.attention_op.apply_attention( - query_proj, key_proj, value_proj, attention_mask=encoder_attention_mask + query_proj, + key_proj, + value_proj, + attention_mask=encoder_attention_mask, ) else: @@ -1599,10 +1635,11 @@ def __call__( with self.conditional_named_scope("proj_query"): query_proj_raw = self.query(hidden_states) - # Image embeddings are padded to multiples of 128 for TPU flash attention + # Image embeddings are padded to multiples of 128 (v5p and below) or 256 (v6e and above) for TPU flash attention # Calculate the padded length to correctly split image and text embeddings if self.added_kv_proj_dim is not None: - alignment = 128 + tpu_type = get_tpu_type() + alignment = 256 if tpu_type in [TpuType.TPU_V6_LITE, TpuType.TPU_7X] else 128 if self.image_seq_len is not None: image_seq_len_actual = self.image_seq_len else: @@ -1635,22 +1672,28 @@ def __call__( query_proj_text = query_proj_raw # Text K/V - with self.conditional_named_scope("proj_key"): - key_proj_text = self.key(encoder_hidden_states_text) - if self.qk_norm: - with self.conditional_named_scope("attn_k_norm"): - key_proj_text = self.norm_k(key_proj_text) - with self.conditional_named_scope("proj_value"): - value_proj_text = self.value(encoder_hidden_states_text) + if cached_kv is not None and "text" in cached_kv: + key_proj_text, value_proj_text = cached_kv["text"] + else: + with 
self.conditional_named_scope("proj_key"): + key_proj_text = self.key(encoder_hidden_states_text) + if self.qk_norm: + with self.conditional_named_scope("attn_k_norm"): + key_proj_text = self.norm_k(key_proj_text) + with self.conditional_named_scope("proj_value"): + value_proj_text = self.value(encoder_hidden_states_text) # Image K/V (only if image embeddings are present) if encoder_hidden_states_img is not None: - with self.conditional_named_scope("add_proj_k"): - key_proj_img = self.add_k_proj(encoder_hidden_states_img) - with self.conditional_named_scope("norm_add_k"): - key_proj_img = self.norm_added_k(key_proj_img) - with self.conditional_named_scope("add_proj_v"): - value_proj_img = self.add_v_proj(encoder_hidden_states_img) + if cached_kv is not None and "image" in cached_kv: + key_proj_img, value_proj_img = cached_kv["image"] + else: + with self.conditional_named_scope("add_proj_k"): + key_proj_img = self.add_k_proj(encoder_hidden_states_img) + with self.conditional_named_scope("norm_add_k"): + key_proj_img = self.norm_added_k(key_proj_img) + with self.conditional_named_scope("add_proj_v"): + value_proj_img = self.add_v_proj(encoder_hidden_states_img) query_proj_img = query_proj_raw # Check norm_added_k too # Checkpointing @@ -1667,7 +1710,10 @@ def __call__( with self.conditional_named_scope("cross_attn_img_apply"): # Pass encoder_attention_mask_img for image cross-attention to mask padded tokens attn_output_img = self.attention_op.apply_attention( - query_proj_img, key_proj_img, value_proj_img, attention_mask=encoder_attention_mask_img + query_proj_img, + key_proj_img, + value_proj_img, + attention_mask=encoder_attention_mask_img, ) attn_output = attn_output_text + attn_output_img @@ -1689,6 +1735,65 @@ def __call__( hidden_states = self.drop_out(hidden_states, deterministic=deterministic, rngs=rngs) return hidden_states + def compute_kv( + self, + encoder_hidden_states: jax.Array, + encoder_attention_mask: Optional[jax.Array] = None, + ) -> Dict[str, Tuple[jax.Array, jax.Array]]: + is_i2v_cross_attention = self.added_kv_proj_dim is not None + + if not is_i2v_cross_attention: + with jax.named_scope("key_proj"): + key_proj = self.key(encoder_hidden_states) + with jax.named_scope("value_proj"): + value_proj = self.value(encoder_hidden_states) + + if self.qk_norm: + with self.conditional_named_scope("attn_k_norm"): + key_proj = self.norm_k(key_proj) + + return {"text": (key_proj, value_proj)} + else: + # Image embeddings are padded to multiples of 128 (v5p and below) or 256 (v6e and above) for TPU flash attention + tpu_type = get_tpu_type() + alignment = 256 if tpu_type in [TpuType.TPU_V6_LITE, TpuType.TPU_7X] else 128 + if self.image_seq_len is not None: + image_seq_len_actual = self.image_seq_len + else: + image_seq_len_actual = 257 + padded_img_len = ((image_seq_len_actual + alignment - 1) // alignment) * alignment + + if encoder_attention_mask is None: + padded_img_len = image_seq_len_actual + + encoder_hidden_states_img = encoder_hidden_states[:, :padded_img_len, :] + encoder_hidden_states_text = encoder_hidden_states[:, padded_img_len:, :] + + # Text K/V + with self.conditional_named_scope("proj_key"): + key_proj_text = self.key(encoder_hidden_states_text) + if self.qk_norm: + with self.conditional_named_scope("attn_k_norm"): + key_proj_text = self.norm_k(key_proj_text) + with self.conditional_named_scope("proj_value"): + value_proj_text = self.value(encoder_hidden_states_text) + + # Image K/V (only if image embeddings are present) + if encoder_hidden_states_img is not None: + 
with self.conditional_named_scope("add_proj_k"): + key_proj_img = self.add_k_proj(encoder_hidden_states_img) + with self.conditional_named_scope("norm_add_k"): + key_proj_img = self.norm_added_k(key_proj_img) + with self.conditional_named_scope("add_proj_v"): + value_proj_img = self.add_v_proj(encoder_hidden_states_img) + + return { + "text": (key_proj_text, value_proj_text), + "image": (key_proj_img, value_proj_img), + } + else: + return {"text": (key_proj_text, value_proj_text)} + class FlaxFluxAttention(nn.Module): query_dim: int @@ -1801,7 +1906,13 @@ def setup(self): param_dtype=self.weights_dtype, ) - def __call__(self, hidden_states, encoder_hidden_states=None, attention_mask=None, image_rotary_emb=None): + def __call__( + self, + hidden_states, + encoder_hidden_states=None, + attention_mask=None, + image_rotary_emb=None, + ): qkv_proj = self.qkv(hidden_states) B, L = hidden_states.shape[:2] H, D, K = self.heads, qkv_proj.shape[-1] // (self.heads * 3), 3 @@ -1973,7 +2084,13 @@ def setup(self): ) self.dropout_layer = nn.Dropout(rate=self.dropout) - def __call__(self, hidden_states, context=None, deterministic=True, cross_attention_kwargs=None): + def __call__( + self, + hidden_states, + context=None, + deterministic=True, + cross_attention_kwargs=None, + ): context = hidden_states if context is None else context query_proj = self.query(hidden_states) key_proj = self.key(context) @@ -2077,7 +2194,11 @@ def setup(self): quant=self.quant, ) self.ff = FlaxFeedForward( - dim=self.dim, dropout=self.dropout, dtype=self.dtype, weights_dtype=self.weights_dtype, precision=self.precision + dim=self.dim, + dropout=self.dropout, + dtype=self.dtype, + weights_dtype=self.weights_dtype, + precision=self.precision, ) self.norm1 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype, param_dtype=self.weights_dtype) self.norm2 = nn.LayerNorm(epsilon=1e-5, dtype=self.dtype, param_dtype=self.weights_dtype) @@ -2089,11 +2210,16 @@ def __call__(self, hidden_states, context, deterministic=True, cross_attention_k residual = hidden_states if self.only_cross_attention: hidden_states = self.attn1( - self.norm1(hidden_states), context, deterministic=deterministic, cross_attention_kwargs=cross_attention_kwargs + self.norm1(hidden_states), + context, + deterministic=deterministic, + cross_attention_kwargs=cross_attention_kwargs, ) else: hidden_states = self.attn1( - self.norm1(hidden_states), deterministic=deterministic, cross_attention_kwargs=cross_attention_kwargs + self.norm1(hidden_states), + deterministic=deterministic, + cross_attention_kwargs=cross_attention_kwargs, ) hidden_states = hidden_states + residual @@ -2101,7 +2227,10 @@ def __call__(self, hidden_states, context, deterministic=True, cross_attention_k # cross attention residual = hidden_states hidden_states = self.attn2( - self.norm2(hidden_states), context, deterministic=deterministic, cross_attention_kwargs=cross_attention_kwargs + self.norm2(hidden_states), + context, + deterministic=deterministic, + cross_attention_kwargs=cross_attention_kwargs, ) hidden_states = hidden_states + residual @@ -2172,7 +2301,12 @@ class FlaxTransformer2DModel(nn.Module): quant: Quant = (None,) def setup(self): - self.norm = nn.GroupNorm(num_groups=self.norm_num_groups, epsilon=1e-5, dtype=self.dtype, param_dtype=self.weights_dtype) + self.norm = nn.GroupNorm( + num_groups=self.norm_num_groups, + epsilon=1e-5, + dtype=self.dtype, + param_dtype=self.weights_dtype, + ) conv_kernel_init = nn.with_logical_partitioning( nn.initializers.lecun_normal(), ("keep_1", "keep_2", 
"conv_in", "conv_out") @@ -2255,7 +2389,10 @@ def __call__(self, hidden_states, context, deterministic=True, cross_attention_k for transformer_block in self.transformer_blocks: hidden_states = transformer_block( - hidden_states, context, deterministic=deterministic, cross_attention_kwargs=cross_attention_kwargs + hidden_states, + context, + deterministic=deterministic, + cross_attention_kwargs=cross_attention_kwargs, ) if self.use_linear_projection: @@ -2298,8 +2435,19 @@ class FlaxFeedForward(nn.Module): def setup(self): # The second linear layer needs to be called # net_2 for now to match the index of the Sequential layer - self.net_0 = FlaxGEGLU(self.dim, self.dropout, self.dtype, self.weights_dtype, precision=self.precision) - self.net_2 = nn.Dense(self.dim, dtype=self.dtype, param_dtype=self.weights_dtype, precision=self.precision) + self.net_0 = FlaxGEGLU( + self.dim, + self.dropout, + self.dtype, + self.weights_dtype, + precision=self.precision, + ) + self.net_2 = nn.Dense( + self.dim, + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + ) def __call__(self, hidden_states, deterministic=True): hidden_states = self.net_0(hidden_states, deterministic=deterministic) @@ -2329,7 +2477,12 @@ class FlaxGEGLU(nn.Module): def setup(self): inner_dim = self.dim * 4 - self.proj = nn.Dense(inner_dim * 2, dtype=self.dtype, param_dtype=self.weights_dtype, precision=self.precision) + self.proj = nn.Dense( + inner_dim * 2, + dtype=self.dtype, + param_dtype=self.weights_dtype, + precision=self.precision, + ) self.dropout_layer = nn.Dropout(rate=self.dropout) def __call__(self, hidden_states, deterministic=True): diff --git a/src/maxdiffusion/models/embeddings_flax.py b/src/maxdiffusion/models/embeddings_flax.py index 36bea9ea3..01dfdec34 100644 --- a/src/maxdiffusion/models/embeddings_flax.py +++ b/src/maxdiffusion/models/embeddings_flax.py @@ -21,6 +21,7 @@ from .modeling_flax_utils import get_activation from ..models.attention_flax import NNXSimpleFeedForward from ..models.normalization_flax import FP32LayerNorm +from maxdiffusion.tpu_utils import get_tpu_type, TpuType def get_sinusoidal_embeddings( @@ -275,7 +276,11 @@ def __init__( precision=precision, ) self.norm2 = FP32LayerNorm(rngs=rngs, dim=out_features, elementwise_affine=True, eps=1e-6) - self.alignment = alignment + if alignment == 128: + tpu_type = get_tpu_type() + self.alignment = 256 if tpu_type in [TpuType.TPU_V6_LITE, TpuType.TPU_7X] else 128 + else: + self.alignment = alignment self.flash_min_seq_length = flash_min_seq_length if pos_embed_seq_len is not None: self.pos_embed = nnx.Param(jnp.zeros((1, pos_embed_seq_len, in_features), dtype=dtype)) diff --git a/src/maxdiffusion/models/wan/transformers/transformer_wan.py b/src/maxdiffusion/models/wan/transformers/transformer_wan.py index 7d721773e..faaf50ee2 100644 --- a/src/maxdiffusion/models/wan/transformers/transformer_wan.py +++ b/src/maxdiffusion/models/wan/transformers/transformer_wan.py @@ -62,7 +62,13 @@ def get_frequencies(max_seq_len: int, theta: int, attention_head_dim: int): class WanRotaryPosEmbed(nnx.Module): - def __init__(self, attention_head_dim: int, patch_size: Tuple[int, int, int], max_seq_len: int, theta: float = 10000.0): + def __init__( + self, + attention_head_dim: int, + patch_size: Tuple[int, int, int], + max_seq_len: int, + theta: float = 10000.0, + ): self.attention_head_dim = attention_head_dim self.patch_size = patch_size self.max_seq_len = max_seq_len @@ -152,18 +158,35 @@ def __init__( ) def __call__( - self, timestep: 
jax.Array, encoder_hidden_states: jax.Array, encoder_hidden_states_image: Optional[jax.Array] = None + self, + timestep: jax.Array, + encoder_hidden_states: jax.Array, + encoder_hidden_states_image: Optional[jax.Array] = None, + skip_embeddings: bool = False, ): timestep = self.timesteps_proj(timestep) temb = self.time_embedder(timestep) with jax.named_scope("time_proj"): timestep_proj = self.time_proj(self.act_fn(temb)) - encoder_hidden_states = self.text_embedder(encoder_hidden_states) - encoder_attention_mask = None - if encoder_hidden_states_image is not None: - encoder_hidden_states_image, encoder_attention_mask = self.image_embedder(encoder_hidden_states_image) - return temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image, encoder_attention_mask + if not skip_embeddings: + encoder_hidden_states = self.text_embedder(encoder_hidden_states) + encoder_attention_mask = None + if encoder_hidden_states_image is not None: + ( + encoder_hidden_states_image, + encoder_attention_mask, + ) = self.image_embedder(encoder_hidden_states_image) + else: + encoder_attention_mask = None + + return ( + temb, + timestep_proj, + encoder_hidden_states, + encoder_hidden_states_image, + encoder_attention_mask, + ) class ApproximateGELU(nnx.Module): @@ -232,7 +255,13 @@ def __init__( self.act_fn = nnx.data(None) if activation_fn == "gelu-approximate": self.act_fn = ApproximateGELU( - rngs=rngs, dim_in=dim, dim_out=inner_dim, bias=bias, dtype=dtype, weights_dtype=weights_dtype, precision=precision + rngs=rngs, + dim_in=dim, + dim_out=inner_dim, + bias=bias, + dtype=dtype, + weights_dtype=weights_dtype, + precision=precision, ) else: raise NotImplementedError(f"{activation_fn} is not implemented.") @@ -259,7 +288,12 @@ def conditional_named_scope(self, name: str): """Return a JAX named scope if enabled, otherwise a null context.""" return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext() - def __call__(self, hidden_states: jax.Array, deterministic: bool = True, rngs: nnx.Rngs = None) -> jax.Array: + def __call__( + self, + hidden_states: jax.Array, + deterministic: bool = True, + rngs: nnx.Rngs = None, + ) -> jax.Array: hidden_states = self.act_fn(hidden_states) # Output is (4, 75600, 13824) hidden_states = checkpoint_name(hidden_states, "ffn_activation") if self.drop_out.rate > 0: @@ -381,6 +415,7 @@ def __call__( deterministic: bool = True, rngs: nnx.Rngs = None, encoder_attention_mask: Optional[jax.Array] = None, + cached_kv: Optional[Dict[str, Tuple[jax.Array, jax.Array]]] = None, ): with self.conditional_named_scope("transformer_block"): # Support both global [B, 6, dim] and per-token [B, seq_len, 6, dim] temb. 
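The cached_kv argument added to WanTransformerBlock.__call__ above is consumed by the cross-attention layer in attention_flax.py: when a {"text": (key, value)} entry (and optionally an "image" entry) is supplied, the per-step key/value projections and norm_k are skipped, since both were already applied when the cache was built. A minimal, self-contained sketch of that contract follows; cross_attn_kv and the toy shapes are illustrative stand-ins, not helpers defined in this patch.

from typing import Dict, Optional, Tuple
import jax.numpy as jnp

def cross_attn_kv(
    encoder_hidden_states: jnp.ndarray,
    key_proj,
    value_proj,
    cached_kv: Optional[Dict[str, Tuple[jnp.ndarray, jnp.ndarray]]] = None,
):
  """Return (key, value), reusing cached projections when available."""
  if cached_kv is not None and "text" in cached_kv:
    # Prompt embeddings do not change across denoising steps, so the cached
    # projections (already passed through norm_k at cache time) are reused as-is.
    return cached_kv["text"]
  return key_proj(encoder_hidden_states), value_proj(encoder_hidden_states)

# Toy usage: the second call hits the cache and skips both projections.
enc = jnp.ones((1, 512, 64))
k, v = cross_attn_kv(enc, lambda x: x + 1.0, lambda x: x * 2.0)
k2, v2 = cross_attn_kv(enc, lambda x: x + 1.0, lambda x: x * 2.0, cached_kv={"text": (k, v)})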
@@ -396,7 +431,14 @@ def __call__( c_scale_msa = parts[4].squeeze(2) c_gate_msa = parts[5].squeeze(2) else: # Global: [B, 6, dim] - shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = jnp.split( + ( + shift_msa, + scale_msa, + gate_msa, + c_shift_msa, + c_scale_msa, + c_gate_msa, + ) = jnp.split( (self.adaln_scale_shift_table + temb.astype(jnp.float32)), 6, axis=1, @@ -435,6 +477,7 @@ def __call__( deterministic=deterministic, rngs=rngs, encoder_attention_mask=encoder_attention_mask, + cached_kv=cached_kv, ) with self.conditional_named_scope("cross_attn_residual"): hidden_states = hidden_states + attn_output @@ -453,6 +496,13 @@ def __call__( ) return hidden_states + def compute_kv( + self, + encoder_hidden_states: jax.Array, + encoder_attention_mask: Optional[jax.Array] = None, + ) -> Dict[str, Tuple[jax.Array, jax.Array]]: + return self.attn2.compute_kv(encoder_hidden_states, encoder_attention_mask) + class WanModel(nnx.Module, FlaxModelMixin, ConfigMixin): @@ -533,7 +583,11 @@ def __init__( # 3. Transformer blocks @nnx.split_rngs(splits=num_layers) - @nnx.vmap(in_axes=0, out_axes=0, transform_metadata={nnx.PARTITION_NAME: "layers_per_stage"}) + @nnx.vmap( + in_axes=0, + out_axes=0, + transform_metadata={nnx.PARTITION_NAME: "layers_per_stage"}, + ) def init_block(rngs): return WanTransformerBlock( rngs=rngs, @@ -609,6 +663,61 @@ def conditional_named_scope(self, name: str): """Return a JAX named scope if enabled, otherwise a null context.""" return jax.named_scope(name) if self.enable_jax_named_scopes else contextlib.nullcontext() + def compute_kv_cache( + self, + encoder_hidden_states: jax.Array, + encoder_hidden_states_image: Optional[jax.Array] = None, + timestep: Optional[jax.Array] = None, + ) -> Tuple[Dict[str, Tuple[jax.Array, jax.Array]], Optional[jax.Array]]: + if timestep is None: + batch_size = encoder_hidden_states.shape[0] + timestep = jnp.zeros((batch_size,), dtype=jnp.int32) + + with self.conditional_named_scope("condition_embedder"): + ( + temb, + timestep_proj, + encoder_hidden_states, + encoder_hidden_states_image, + encoder_attention_mask, + ) = self.condition_embedder(timestep, encoder_hidden_states, encoder_hidden_states_image) + + if encoder_hidden_states_image is not None: + encoder_hidden_states = jnp.concatenate([encoder_hidden_states_image, encoder_hidden_states], axis=1) + if encoder_attention_mask is not None: + text_mask = jnp.ones( + ( + encoder_hidden_states.shape[0], + encoder_hidden_states.shape[1] - encoder_hidden_states_image.shape[1], + ), + dtype=jnp.int32, + ) + encoder_attention_mask = jnp.concatenate([encoder_attention_mask, text_mask], axis=1) + + if self.scan_layers: + + @nnx.vmap( + in_axes=(0, None, None), + out_axes=0, + transform_metadata={nnx.PARTITION_NAME: "layers_per_stage"}, + ) + def _compute_kv(block, enc_states, enc_mask): + return block.compute_kv(enc_states, enc_mask) + + kv_cache = _compute_kv(self.blocks, encoder_hidden_states, encoder_attention_mask) + else: + kv_cache_list = [] + for block in self.blocks: + kv_cache_list.append(block.compute_kv(encoder_hidden_states, encoder_attention_mask)) + keys = kv_cache_list[0].keys() + kv_cache = {} + for k in keys: + k_list = [d[k][0] for d in kv_cache_list] + v_list = [d[k][1] for d in kv_cache_list] + kv_cache[k] = (jnp.stack(k_list, axis=0), jnp.stack(v_list, axis=0)) + + return kv_cache, encoder_attention_mask + @jax.named_scope("WanModel") def __call__( self, @@ -623,6 +732,9 @@ def __call__( skip_blocks: Optional[jax.Array] = None, cached_residual: 
Optional[jax.Array] = None, return_residual: bool = False, + kv_cache: Optional[Dict[str, Tuple[jax.Array, jax.Array]]] = None, + rotary_emb: Optional[jax.Array] = None, + encoder_attention_mask: Optional[jax.Array] = None, ) -> Union[jax.Array, Tuple[jax.Array, jax.Array], Dict[str, jax.Array]]: hidden_states = nn.with_logical_constraint(hidden_states, ("batch", None, None, None, None)) batch_size, _, num_frames, height, width = hidden_states.shape @@ -633,7 +745,8 @@ def __call__( hidden_states = jnp.transpose(hidden_states, (0, 2, 3, 4, 1)) with self.conditional_named_scope("rotary_embedding"): - rotary_emb = self.rope(hidden_states) + if rotary_emb is None: + rotary_emb = self.rope(hidden_states) with self.conditional_named_scope("patch_embedding"): hidden_states = self.patch_embedding(hidden_states) hidden_states = jax.lax.collapse(hidden_states, 1, -1) @@ -659,27 +772,51 @@ def __call__( ( temb, timestep_proj, + encoder_hidden_states_out, + encoder_hidden_states_image_out, + encoder_attention_mask_out, + ) = self.condition_embedder( + timestep, encoder_hidden_states, encoder_hidden_states_image, - encoder_attention_mask, - ) = self.condition_embedder(timestep, encoder_hidden_states, encoder_hidden_states_image) + skip_embeddings=(kv_cache is not None), + ) timestep_proj = timestep_proj.reshape(timestep_proj.shape[0], 6, -1) + if kv_cache is not None and encoder_attention_mask is not None: + encoder_attention_mask = encoder_attention_mask + else: + encoder_attention_mask = encoder_attention_mask_out + if encoder_hidden_states_image is not None: - encoder_hidden_states = jnp.concatenate([encoder_hidden_states_image, encoder_hidden_states], axis=1) - if encoder_attention_mask is not None: + encoder_hidden_states = jnp.concatenate([encoder_hidden_states_image, encoder_hidden_states_out], axis=1) + if kv_cache is None and encoder_attention_mask is not None: text_mask = jnp.ones( - (encoder_hidden_states.shape[0], encoder_hidden_states.shape[1] - encoder_hidden_states_image.shape[1]), + ( + encoder_hidden_states.shape[0], + encoder_hidden_states.shape[1] - encoder_hidden_states_image.shape[1], + ), dtype=jnp.int32, ) encoder_attention_mask = jnp.concatenate([encoder_attention_mask, text_mask], axis=1) encoder_hidden_states = encoder_hidden_states.astype(hidden_states.dtype) + else: + if per_token_t: + encoder_hidden_states = encoder_hidden_states.astype(hidden_states.dtype) + else: + encoder_hidden_states = encoder_hidden_states_out.astype(hidden_states.dtype) def _run_all_blocks(h): if self.scan_layers: - def scan_fn(carry, block): + def scan_fn(carry, block_input): hidden_states_carry, rngs_carry = carry + if kv_cache is not None: + block, layer_kv_cache = block_input + else: + block = block_input + layer_kv_cache = None + hidden_states = block( hidden_states_carry, encoder_hidden_states, @@ -688,27 +825,40 @@ def scan_fn(carry, block): deterministic, rngs_carry, encoder_attention_mask, + cached_kv=layer_kv_cache, ) new_carry = (hidden_states, rngs_carry) return new_carry, None rematted_block_forward = self.gradient_checkpoint.apply( - scan_fn, self.names_which_can_be_saved, self.names_which_can_be_offloaded, prevent_cse=not self.scan_layers + scan_fn, + self.names_which_can_be_saved, + self.names_which_can_be_offloaded, + prevent_cse=not self.scan_layers, ) initial_carry = (h, rngs) + + if kv_cache is not None: + scan_input = (self.blocks, kv_cache) + else: + scan_input = self.blocks + final_carry, _ = nnx.scan( rematted_block_forward, length=self.num_layers, in_axes=(nnx.Carry, 0), 
out_axes=(nnx.Carry, 0), - )(initial_carry, self.blocks) + )(initial_carry, scan_input) h_out, _ = final_carry else: h_out = h - for block in self.blocks: + for i, block in enumerate(self.blocks): + layer_kv_cache = None + if kv_cache is not None: + layer_kv_cache = jax.tree_map(lambda x: x[i], kv_cache) - def layer_forward(hidden_states): + def layer_forward(hidden_states, l_kv): return block( hidden_states, encoder_hidden_states, @@ -717,6 +867,7 @@ def layer_forward(hidden_states): deterministic, rngs, encoder_attention_mask=encoder_attention_mask, + cached_kv=l_kv, ) rematted_layer_forward = self.gradient_checkpoint.apply( @@ -725,7 +876,7 @@ def layer_forward(hidden_states): self.names_which_can_be_offloaded, prevent_cse=not self.scan_layers, ) - h_out = rematted_layer_forward(h_out) + h_out = rematted_layer_forward(h_out, layer_kv_cache) return h_out hidden_states_before_blocks = hidden_states @@ -752,7 +903,14 @@ def layer_forward(hidden_states): hidden_states = self.proj_out(hidden_states) hidden_states = hidden_states.reshape( - batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1 + batch_size, + post_patch_num_frames, + post_patch_height, + post_patch_width, + p_t, + p_h, + p_w, + -1, ) hidden_states = jnp.transpose(hidden_states, (0, 7, 1, 4, 2, 5, 3, 6)) hidden_states = hidden_states.reshape(batch_size, -1, num_frames, height, width) diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index 608f7282d..8429eaf74 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -173,7 +173,8 @@ def create_model(rngs: nnx.Rngs, wan_config: dict): ) params = jax.tree_util.tree_map_with_path( - lambda path, x: cast_with_exclusion(path, x, dtype_to_cast=config.weights_dtype), params + lambda path, x: cast_with_exclusion(path, x, dtype_to_cast=config.weights_dtype), + params, ) for path, val in flax.traverse_util.flatten_dict(params).items(): if restored_checkpoint: @@ -291,7 +292,9 @@ def load_image_encoder(cls, config: HyperParameters): image_processor = CLIPImageProcessor.from_pretrained(config.pretrained_model_name_or_path, subfolder="image_processor") try: image_encoder = FlaxCLIPVisionModel.from_pretrained( - config.pretrained_model_name_or_path, subfolder="image_encoder", dtype=jnp.float32 + config.pretrained_model_name_or_path, + subfolder="image_encoder", + dtype=jnp.float32, ) except Exception as e: max_logging.error(f"Failed to load FlaxCLIPVisionModel: {e}") @@ -300,7 +303,12 @@ def load_image_encoder(cls, config: HyperParameters): @classmethod def load_vae( - cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters, vae_logical_axis_rules: tuple = None + cls, + devices_array: np.array, + mesh: Mesh, + rngs: nnx.Rngs, + config: HyperParameters, + vae_logical_axis_rules: tuple = None, ): def create_model(rngs: nnx.Rngs, config: HyperParameters): wan_vae = AutoencoderKLWan.from_config( @@ -403,7 +411,13 @@ def get_qt_provider(cls, config: HyperParameters) -> Optional[qwix.QtProvider]: return None @classmethod - def quantize_transformer(cls, config: HyperParameters, model: WanModel, pipeline: "WanPipeline", mesh: Mesh): + def quantize_transformer( + cls, + config: HyperParameters, + model: WanModel, + pipeline: "WanPipeline", + mesh: Mesh, + ): """Quantizes the transformer model.""" q_rules = cls.get_qt_provider(config) if not q_rules: @@ -484,7 +498,8 @@ def _get_t5_prompt_embeds( prompt_embeds = 
self.text_encoder(text_input_ids, mask).last_hidden_state prompt_embeds = [u[:v] for u, v in zip(prompt_embeds, seq_lens)] prompt_embeds = torch.stack( - [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], dim=0 + [torch.cat([u, u.new_zeros(max_sequence_length - u.size(0), u.size(1))]) for u in prompt_embeds], + dim=0, ) # duplicate text embeddings for each generation per prompt, using mps friendly method @@ -587,12 +602,26 @@ def prepare_latents_i2v_base( if last_image is None: video_condition = jnp.concatenate( - [image, jnp.zeros((image.shape[0], image.shape[1], num_frames - 1, height, width), dtype=image.dtype)], axis=2 + [ + image, + jnp.zeros( + (image.shape[0], image.shape[1], num_frames - 1, height, width), + dtype=image.dtype, + ), + ], + axis=2, ) else: last_image = last_image[:, :, jnp.newaxis, :, :] video_condition = jnp.concatenate( - [image, jnp.zeros((image.shape[0], image.shape[1], num_frames - 2, height, width), dtype=image.dtype), last_image], + [ + image, + jnp.zeros( + (image.shape[0], image.shape[1], num_frames - 2, height, width), + dtype=image.dtype, + ), + last_image, + ], axis=2, ) @@ -679,7 +708,11 @@ def _create_common_components(cls, config, vae_only=False, i2v=False): with vae_mesh: wan_vae, vae_cache = cls.load_vae( - devices_array=devices_array, mesh=vae_mesh, rngs=rngs, config=config, vae_logical_axis_rules=vae_logical_axis_rules + devices_array=devices_array, + mesh=vae_mesh, + rngs=rngs, + config=config, + vae_logical_axis_rules=vae_logical_axis_rules, ) components = { @@ -703,7 +736,10 @@ def _create_common_components(cls, config, vae_only=False, i2v=False): components["text_encoder"] = cls.load_text_encoder(config=config) components["scheduler"], components["scheduler_state"] = cls.load_scheduler(config=config) if i2v and config.model_name == "wan2.1": - components["image_processor"], components["image_encoder"] = cls.load_image_encoder(config) + ( + components["image_processor"], + components["image_encoder"], + ) = cls.load_image_encoder(config) return components @abstractmethod @@ -836,10 +872,18 @@ def _prepare_model_inputs( negative_prompt_embeds = jax.device_put(negative_prompt_embeds, data_sharding) scheduler_state = self.scheduler.set_timesteps( - self.scheduler_state, num_inference_steps=num_inference_steps, shape=latents.shape + self.scheduler_state, + num_inference_steps=num_inference_steps, + shape=latents.shape, ) - return latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames + return ( + latents, + prompt_embeds, + negative_prompt_embeds, + scheduler_state, + num_frames, + ) @abstractmethod def __call__(self, **kwargs): @@ -847,7 +891,15 @@ def __call__(self, **kwargs): pass -@partial(jax.jit, static_argnames=("do_classifier_free_guidance", "guidance_scale", "return_residual", "skip_blocks")) +@partial( + jax.jit, + static_argnames=( + "do_classifier_free_guidance", + "guidance_scale", + "return_residual", + "skip_blocks", + ), +) def transformer_forward_pass( graphdef, sharded_state, @@ -861,6 +913,9 @@ def transformer_forward_pass( skip_blocks=None, cached_residual=None, return_residual=False, + kv_cache=None, + rotary_emb=None, + encoder_attention_mask=None, ): wan_transformer = nnx.merge(graphdef, sharded_state, rest_of_state) outputs = wan_transformer( @@ -871,6 +926,9 @@ def transformer_forward_pass( skip_blocks=skip_blocks, cached_residual=cached_residual, return_residual=return_residual, + kv_cache=kv_cache, + rotary_emb=rotary_emb, + 
encoder_attention_mask=encoder_attention_mask, ) if return_residual: @@ -901,6 +959,9 @@ def transformer_forward_pass_full_cfg( prompt_embeds_combined: jnp.array, guidance_scale: float, encoder_hidden_states_image=None, + kv_cache=None, + rotary_emb=None, + encoder_attention_mask=None, ): """Full CFG forward pass. @@ -919,6 +980,9 @@ def transformer_forward_pass_full_cfg( skip_blocks=False, cached_residual=None, return_residual=False, + kv_cache=kv_cache, + rotary_emb=rotary_emb, + encoder_attention_mask=encoder_attention_mask, ) noise_cond = noise_pred[:bsz] noise_uncond = noise_pred[bsz:] @@ -940,6 +1004,9 @@ def transformer_forward_pass_cfg_cache( w1: float = 1.0, w2: float = 1.0, encoder_hidden_states_image=None, + kv_cache=None, + rotary_emb=None, + encoder_attention_mask=None, ): """CFG-Cache forward pass with FFT frequency-domain compensation. @@ -965,6 +1032,9 @@ def transformer_forward_pass_cfg_cache( timestep=timestep_cond, encoder_hidden_states=prompt_cond_embeds, encoder_hidden_states_image=encoder_hidden_states_image, + kv_cache=kv_cache, + rotary_emb=rotary_emb, + encoder_attention_mask=encoder_attention_mask, ) # FFT over spatial dims (H, W) — last 2 dims of [B, C, F, H, W] diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py index e0a2f05e6..2cb27552c 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_1.py @@ -73,7 +73,13 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform return pipeline @classmethod - def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_only=False, load_transformer=True): + def from_checkpoint( + cls, + config: HyperParameters, + restored_checkpoint=None, + vae_only=False, + load_transformer=True, + ): pipeline, _ = cls._load_and_init(config, restored_checkpoint, vae_only, load_transformer) return pipeline @@ -100,6 +106,7 @@ def __call__( magcache_thresh: Optional[float] = None, magcache_K: Optional[int] = None, retention_ratio: Optional[float] = None, + use_kv_cache: bool = False, ): config = getattr(self, "config", None) if magcache_thresh is None: @@ -114,7 +121,6 @@ def __call__( f"use_cfg_cache=True requires guidance_scale > 1.0 (got {guidance_scale}). " "CFG cache accelerates classifier-free guidance, which is disabled when guidance_scale <= 1.0." ) - trace = {} t_cond_start = time.perf_counter() @@ -152,6 +158,7 @@ def __call__( height=height, mag_ratios_base=getattr(config, "mag_ratios_base", None), config=self.config, + use_kv_cache=use_kv_cache, ) t_denoise_start = time.perf_counter() @@ -196,6 +203,7 @@ def run_inference_2_1( height: int = 480, mag_ratios_base: Optional[List[float]] = None, config=None, + use_kv_cache: bool = False, ): """Denoising loop for WAN 2.1 T2V with FasterCache CFG-Cache. 
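The hunks below apply the same two-phase pattern to each denoising loop: hoist the step-invariant work out of the loop (the RoPE tables, which depend only on the latent shape, and the cross-attention K/V plus its attention mask, which depend only on the prompt embeddings), then thread all three into every transformer call. A toy sketch of that flow, in which FakeTransformer and every shape are illustrative stand-ins; in the real code the object is the merged WanModel and the arrays are those already in scope in run_inference_2_1.

import jax.numpy as jnp

class FakeTransformer:
  """Stand-in for the merged WanModel, only to make the call pattern concrete."""

  def rope(self, hidden_states):
    # Depends only on the latent shape, so it can be computed once.
    return jnp.zeros(hidden_states.shape[1:])

  def compute_kv_cache(self, prompt_embeds):
    # Depends only on the prompt embeddings; returns (kv_cache, encoder_attention_mask).
    return {"text": (prompt_embeds, prompt_embeds)}, None

  def __call__(self, latents, timestep, prompt_embeds, *, kv_cache, rotary_emb,
               encoder_attention_mask):
    return latents  # pretend noise prediction

transformer = FakeTransformer()
latents = jnp.zeros((1, 16, 4, 8, 8))        # [B, C, F, H, W], toy sizes
prompt_embeds = jnp.zeros((2, 512, 4096))    # cond + uncond, CFG-doubled

rotary_emb = transformer.rope(jnp.zeros((1, 4, 8, 8, 16)))     # once per shape
kv_cache, mask = transformer.compute_kv_cache(prompt_embeds)   # once per prompt
for t in jnp.arange(3):                                         # reused at every step
  noise_pred = transformer(latents, t, prompt_embeds,
                           kv_cache=kv_cache, rotary_emb=rotary_emb,
                           encoder_attention_mask=mask)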
@@ -267,6 +275,26 @@ def run_inference_2_1( cached_noise_cond = None cached_noise_uncond = None + transformer_obj = nnx.merge(graphdef, sharded_state, rest_of_state) + + # Compute RoPE once as it only depends on shape + dummy_hidden_states = jnp.zeros(( + latents.shape[0], + latents.shape[2], + latents.shape[3], + latents.shape[4], + latents.shape[1], + )) + rotary_emb = transformer_obj.rope(dummy_hidden_states) + + kv_cache = None + encoder_attention_mask = None + + if use_kv_cache: + kv_cache, encoder_attention_mask = transformer_obj.compute_kv_cache( + prompt_embeds_combined if do_cfg else prompt_cond_embeds + ) + if use_magcache and do_cfg: magcache_init = init_magcache(num_inference_steps, retention_ratio, mag_ratios_base) accumulated_state = magcache_init[:6] @@ -276,7 +304,11 @@ def run_inference_2_1( first_profiling_step = config.skip_first_n_steps_for_profiler if config else 0 profiler_steps = config.profiler_steps if config else 0 - last_profiling_step = np.clip(first_profiling_step + profiler_steps - 1, first_profiling_step, num_inference_steps - 1) + last_profiling_step = np.clip( + first_profiling_step + profiler_steps - 1, + first_profiling_step, + num_inference_steps - 1, + ) scan_diffusion_loop = getattr(config, "scan_diffusion_loop", False) if config else False @@ -338,7 +370,12 @@ def scan_body(carry, t): timestep = jnp.broadcast_to(t, bsz * 2 if do_cfg else bsz) skip_blocks, accumulated_state = magcache_step( - step, mag_ratios, accumulated_state, magcache_thresh, magcache_K, skip_warmup + step, + mag_ratios, + accumulated_state, + magcache_thresh, + magcache_K, + skip_warmup, ) noise_pred, latents, residual_x_cur = transformer_forward_pass( @@ -353,6 +390,9 @@ def scan_body(carry, t): skip_blocks=bool(skip_blocks), cached_residual=cached_residual, return_residual=True, + kv_cache=kv_cache, + rotary_emb=rotary_emb, + encoder_attention_mask=encoder_attention_mask, ) if not skip_blocks: @@ -364,6 +404,8 @@ def scan_body(carry, t): if is_cache_step: w1, w2 = step_w1w2[step] timestep = jnp.broadcast_to(t, bsz) + kv_cache_cond = jax.tree_map(lambda x: x[:, :bsz], kv_cache) if kv_cache is not None else None + encoder_attention_mask_cond = encoder_attention_mask[:bsz] if encoder_attention_mask is not None else None noise_pred, cached_noise_cond = transformer_forward_pass_cfg_cache( graphdef, sharded_state, @@ -376,12 +418,19 @@ def scan_body(carry, t): guidance_scale=guidance_scale, w1=jnp.float32(w1), w2=jnp.float32(w2), + kv_cache=kv_cache_cond, + rotary_emb=rotary_emb, + encoder_attention_mask=encoder_attention_mask_cond, ) elif do_cfg: latents_doubled = jnp.concatenate([latents] * 2) timestep = jnp.broadcast_to(t, bsz * 2) - noise_pred, cached_noise_cond, cached_noise_uncond = transformer_forward_pass_full_cfg( + ( + noise_pred, + cached_noise_cond, + cached_noise_uncond, + ) = transformer_forward_pass_full_cfg( graphdef, sharded_state, rest_of_state, @@ -389,6 +438,9 @@ def scan_body(carry, t): timestep, prompt_embeds_combined, guidance_scale=guidance_scale, + kv_cache=kv_cache, + rotary_emb=rotary_emb, + encoder_attention_mask=encoder_attention_mask, ) else: @@ -402,6 +454,9 @@ def scan_body(carry, t): prompt_cond_embeds, do_classifier_free_guidance=False, guidance_scale=guidance_scale, + kv_cache=kv_cache, + rotary_emb=rotary_emb, + encoder_attention_mask=encoder_attention_mask, ) latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py 
b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py index 77331d66d..b76d60d9a 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_2_2.py @@ -89,7 +89,13 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform return pipeline @classmethod - def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_only=False, load_transformer=True): + def from_checkpoint( + cls, + config: HyperParameters, + restored_checkpoint=None, + vae_only=False, + load_transformer=True, + ): pipeline, low_noise_transformer, high_noise_transformer = cls._load_and_init( config, restored_checkpoint, vae_only, load_transformer ) @@ -116,6 +122,7 @@ def __call__( vae_only: bool = False, use_cfg_cache: bool = False, use_sen_cache: bool = False, + use_kv_cache: bool = False, ): if use_cfg_cache and use_sen_cache: raise ValueError("use_cfg_cache and use_sen_cache are mutually exclusive. Enable only one.") @@ -137,7 +144,13 @@ def __call__( trace = {} t_cond_start = time.perf_counter() - latents, prompt_embeds, negative_prompt_embeds, scheduler_state, num_frames = self._prepare_model_inputs( + ( + latents, + prompt_embeds, + negative_prompt_embeds, + scheduler_state, + num_frames, + ) = self._prepare_model_inputs( prompt, negative_prompt, height, @@ -171,6 +184,7 @@ def __call__( use_cfg_cache=use_cfg_cache, use_sen_cache=use_sen_cache, height=height, + use_kv_cache=use_kv_cache, ) t_denoise_start = time.perf_counter() @@ -220,6 +234,7 @@ def run_inference_2_2( use_sen_cache: bool = False, height: int = 480, config=None, + use_kv_cache: bool = False, ): """Denoising loop for WAN 2.2 T2V with optional caching acceleration. @@ -239,6 +254,33 @@ def run_inference_2_2( do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0 bsz = latents.shape[0] + prompt_embeds_combined = ( + jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) if do_classifier_free_guidance else prompt_embeds + ) + + low_transformer = nnx.merge(low_noise_graphdef, low_noise_state, low_noise_rest) + + # Compute RoPE once as it only depends on shape + dummy_hidden_states = jnp.zeros(( + latents.shape[0], + latents.shape[2], + latents.shape[3], + latents.shape[4], + latents.shape[1], + )) + rotary_emb = low_transformer.rope(dummy_hidden_states) + + kv_cache_low = None + encoder_attention_mask_low = None + kv_cache_high = None + encoder_attention_mask_high = None + + if use_kv_cache: + kv_cache_low, encoder_attention_mask_low = low_transformer.compute_kv_cache(prompt_embeds_combined) + + high_transformer = nnx.merge(high_noise_graphdef, high_noise_state, high_noise_rest) + kv_cache_high, encoder_attention_mask_high = high_transformer.compute_kv_cache(prompt_embeds_combined) + # ── SenCache path (arXiv:2602.24208) ── if use_sen_cache and do_classifier_free_guidance: timesteps_np = np.array(scheduler_state.timesteps, dtype=np.int32) @@ -262,8 +304,6 @@ def run_inference_2_2( # uses sigmas in [0, 1]. Without normalization |Δt|≈20 >> ε and nothing caches. 
num_train_timesteps = float(scheduler.config.num_train_timesteps) - prompt_embeds_combined = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) - # SenCache state ref_noise_pred = None # y^r: cached denoiser output ref_latent = None # x^r: latent at last cache refresh @@ -279,11 +319,23 @@ def run_inference_2_2( # Select transformer and guidance scale if step_uses_high[step]: - graphdef, state, rest = high_noise_graphdef, high_noise_state, high_noise_rest + graphdef, state, rest = ( + high_noise_graphdef, + high_noise_state, + high_noise_rest, + ) guidance_scale = guidance_scale_high + kv_cache = kv_cache_high + encoder_attention_mask = encoder_attention_mask_high else: - graphdef, state, rest = low_noise_graphdef, low_noise_state, low_noise_rest + graphdef, state, rest = ( + low_noise_graphdef, + low_noise_state, + low_noise_rest, + ) guidance_scale = guidance_scale_low + kv_cache = kv_cache_low + encoder_attention_mask = encoder_attention_mask_low # Force full compute: warmup, first 30%, last 10%, or transformer boundary is_boundary = step > 0 and step_uses_high[step] != step_uses_high[step - 1] @@ -302,6 +354,9 @@ def run_inference_2_2( timestep, prompt_embeds_combined, guidance_scale=guidance_scale, + kv_cache=kv_cache, + rotary_emb=rotary_emb, + encoder_attention_mask=encoder_attention_mask, ) ref_noise_pred = noise_pred ref_latent = latents @@ -338,6 +393,9 @@ def run_inference_2_2( timestep, prompt_embeds_combined, guidance_scale=guidance_scale, + kv_cache=kv_cache, + rotary_emb=rotary_emb, + encoder_attention_mask=encoder_attention_mask, ) ref_noise_pred = noise_pred ref_latent = latents @@ -374,7 +432,6 @@ def run_inference_2_2( # Pre-split embeds once prompt_cond_embeds = prompt_embeds - prompt_embeds_combined = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) # Determine the first low-noise step (boundary transition). 
# In Wan 2.2 the boundary IS the structural→detail transition, so @@ -420,16 +477,30 @@ def run_inference_2_2( # Select transformer and guidance scale based on precomputed schedule if step_uses_high[step]: - graphdef, state, rest = high_noise_graphdef, high_noise_state, high_noise_rest + graphdef, state, rest = ( + high_noise_graphdef, + high_noise_state, + high_noise_rest, + ) guidance_scale = guidance_scale_high + kv_cache = kv_cache_high + encoder_attention_mask = encoder_attention_mask_high else: - graphdef, state, rest = low_noise_graphdef, low_noise_state, low_noise_rest + graphdef, state, rest = ( + low_noise_graphdef, + low_noise_state, + low_noise_rest, + ) guidance_scale = guidance_scale_low + kv_cache = kv_cache_low + encoder_attention_mask = encoder_attention_mask_low if is_cache_step: # ── Cache step: cond-only forward + FFT frequency compensation ── w1, w2 = step_w1w2[step] timestep = jnp.broadcast_to(t, bsz) + kv_cache_cond = jax.tree_map(lambda x: x[:, :bsz], kv_cache) if kv_cache is not None else None + encoder_attention_mask_cond = encoder_attention_mask[:bsz] if encoder_attention_mask is not None else None noise_pred, cached_noise_cond = transformer_forward_pass_cfg_cache( graphdef, state, @@ -442,12 +513,19 @@ def run_inference_2_2( guidance_scale=guidance_scale, w1=jnp.float32(w1), w2=jnp.float32(w2), + kv_cache=kv_cache_cond, + rotary_emb=rotary_emb, + encoder_attention_mask=encoder_attention_mask_cond, ) else: # ── Full CFG step: doubled batch, store raw cond/uncond for cache ── latents_doubled = jnp.concatenate([latents] * 2) timestep = jnp.broadcast_to(t, bsz * 2) - noise_pred, cached_noise_cond, cached_noise_uncond = transformer_forward_pass_full_cfg( + ( + noise_pred, + cached_noise_cond, + cached_noise_uncond, + ) = transformer_forward_pass_full_cfg( graphdef, state, rest, @@ -455,6 +533,9 @@ def run_inference_2_2( timestep, prompt_embeds_combined, guidance_scale=guidance_scale, + kv_cache=kv_cache, + rotary_emb=rotary_emb, + encoder_attention_mask=encoder_attention_mask, ) latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() @@ -467,13 +548,13 @@ def run_inference_2_2( timesteps_np = np.array(scheduler_state.timesteps, dtype=np.int32) step_uses_high = [bool(timesteps_np[s] >= boundary) for s in range(num_inference_steps)] - prompt_embeds_combined = ( - jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) if do_classifier_free_guidance else prompt_embeds - ) - first_profiling_step = config.skip_first_n_steps_for_profiler if config else 0 profiler_steps = config.profiler_steps if config else 0 - last_profiling_step = np.clip(first_profiling_step + profiler_steps - 1, first_profiling_step, num_inference_steps - 1) + last_profiling_step = np.clip( + first_profiling_step + profiler_steps - 1, + first_profiling_step, + num_inference_steps - 1, + ) scan_diffusion_loop = getattr(config, "scan_diffusion_loop", False) if config else False @@ -543,11 +624,19 @@ def scan_body(carry, t): t = jnp.array(scheduler_state.timesteps, dtype=jnp.int32)[step] if step_uses_high[step]: - graphdef, state, rest = high_noise_graphdef, high_noise_state, high_noise_rest + graphdef, state, rest = ( + high_noise_graphdef, + high_noise_state, + high_noise_rest, + ) guidance_scale = guidance_scale_high + kv_cache = kv_cache_high + encoder_attention_mask = encoder_attention_mask_high else: graphdef, state, rest = low_noise_graphdef, low_noise_state, low_noise_rest guidance_scale = guidance_scale_low + kv_cache = kv_cache_low + 
encoder_attention_mask = encoder_attention_mask_low if do_classifier_free_guidance: latents_doubled = jnp.concatenate([latents] * 2) @@ -560,6 +649,9 @@ def scan_body(carry, t): timestep, prompt_embeds_combined, guidance_scale=guidance_scale, + kv_cache=kv_cache, + rotary_emb=rotary_emb, + encoder_attention_mask=encoder_attention_mask, ) else: timestep = jnp.broadcast_to(t, bsz) @@ -572,6 +664,9 @@ def scan_body(carry, t): prompt_embeds, do_classifier_free_guidance, guidance_scale, + kv_cache=kv_cache, + rotary_emb=rotary_emb, + encoder_attention_mask=encoder_attention_mask, ) latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple() diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py index 0abe4fa5b..8711b493d 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p1.py @@ -77,7 +77,13 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform return pipeline @classmethod - def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_only=False, load_transformer=True): + def from_checkpoint( + cls, + config: HyperParameters, + restored_checkpoint=None, + vae_only=False, + load_transformer=True, + ): pipeline, _ = cls._load_and_init(config, restored_checkpoint, vae_only, load_transformer) return pipeline @@ -113,7 +119,13 @@ def prepare_latents( latent_height = height // self.vae_scale_factor_spatial latent_width = width // self.vae_scale_factor_spatial - shape = (batch_size, num_latent_frames, latent_height, latent_width, num_channels_latents) + shape = ( + batch_size, + num_latent_frames, + latent_height, + latent_width, + num_channels_latents, + ) if latents is None: latents = jax.random.normal(rng, shape=shape, dtype=jnp.float32) @@ -129,7 +141,12 @@ def prepare_latents( first_frame_mask = jnp.repeat(first_frame_mask, self.vae_scale_factor_temporal, axis=2) mask_lat_size = jnp.concatenate([first_frame_mask, mask_lat_size[:, :, 1:]], axis=2) mask_lat_size = mask_lat_size.reshape( - batch_size, 1, num_latent_frames, self.vae_scale_factor_temporal, latent_height, latent_width + batch_size, + 1, + num_latent_frames, + self.vae_scale_factor_temporal, + latent_height, + latent_width, ) mask_lat_size = jnp.transpose(mask_lat_size, (0, 2, 4, 5, 3, 1)).squeeze(-1) condition = jnp.concatenate([mask_lat_size, latent_condition], axis=-1) @@ -158,6 +175,7 @@ def __call__( magcache_thresh: Optional[float] = None, magcache_K: Optional[int] = None, retention_ratio: Optional[float] = None, + use_kv_cache: bool = False, ): config = getattr(self, "config", None) if magcache_thresh is None: @@ -180,7 +198,6 @@ def __call__( num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1 max_logging.log(f"Adjusted num_frames to: {num_frames}") num_frames = max(num_frames, 1) - trace = {} t_cond_start = time.perf_counter() @@ -231,7 +248,9 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt): trace["conditioning"] = time.perf_counter() - t_cond_start scheduler_state = self.scheduler.set_timesteps( - self.scheduler_state, num_inference_steps=num_inference_steps, shape=latents.shape + self.scheduler_state, + num_inference_steps=num_inference_steps, + shape=latents.shape, ) graphdef, state, rest_of_state = nnx.split(self.transformer, nnx.Param, ...) 
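The loops above index the precomputed cache in two ways: per layer via the leading axis (the unscanned block loop in transformer_wan.py uses jax.tree_map(lambda x: x[i], kv_cache)) and per batch half on cond-only CFG-cache steps (jax.tree_map(lambda x: x[:, :bsz], kv_cache)). A small self-contained sketch of the layout those slices assume, with toy sizes only:

import jax
import jax.numpy as jnp

num_layers, batch, text_len, dim = 4, 2, 512, 64   # toy sizes
kv_cache = {"text": (jnp.zeros((num_layers, batch, text_len, dim)),
                     jnp.zeros((num_layers, batch, text_len, dim)))}

# K/V for a single transformer block (leading "layers" axis added by compute_kv_cache).
layer_kv = jax.tree_util.tree_map(lambda x: x[2], kv_cache)
# Cond-only view for CFG-cache steps (keeps the first half of the CFG-doubled batch).
cond_kv = jax.tree_util.tree_map(lambda x: x[:, :1], kv_cache)

assert layer_kv["text"][0].shape == (batch, text_len, dim)
assert cond_kv["text"][0].shape == (num_layers, 1, text_len, dim)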
@@ -262,6 +281,7 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt):
         height=height,
         mag_ratios_base=self.config.mag_ratios_base_720p if height >= 720 else self.config.mag_ratios_base_480p,
         config=self.config,
+        use_kv_cache=use_kv_cache,
     )
 
     t_denoise_start = time.perf_counter()
@@ -311,6 +331,7 @@ def run_inference_2_1_i2v(
     height: int = 480,
     mag_ratios_base: Optional[List[float]] = None,
     config=None,
+    use_kv_cache: bool = False,
 ):
   do_cfg = guidance_scale > 1.0
 
@@ -330,9 +351,25 @@ def run_inference_2_1_i2v(
     image_embeds_combined = image_embeds
     condition_combined = condition
 
+  transformer_obj = nnx.merge(graphdef, sharded_state, rest_of_state)
+
+  # Compute RoPE once as it only depends on shape
+  dummy_hidden_states = jnp.zeros(latents.shape)
+  rotary_emb = transformer_obj.rope(dummy_hidden_states)
+
+  kv_cache = None
+  encoder_attention_mask = None
+
+  if use_kv_cache:
+    kv_cache, encoder_attention_mask = transformer_obj.compute_kv_cache(prompt_embeds_combined, image_embeds_combined)
+
   first_profiling_step = config.skip_first_n_steps_for_profiler if config else 0
   profiler_steps = config.profiler_steps if config else 0
-  last_profiling_step = np.clip(first_profiling_step + profiler_steps - 1, first_profiling_step, num_inference_steps - 1)
+  last_profiling_step = np.clip(
+      first_profiling_step + profiler_steps - 1,
+      first_profiling_step,
+      num_inference_steps - 1,
+  )
 
   scan_diffusion_loop = getattr(config, "scan_diffusion_loop", False) if config else False
 
@@ -393,7 +430,12 @@ def scan_body(carry, t):
     skip_blocks = False
     if use_magcache and do_cfg:
       skip_blocks, accumulated_state = magcache_step(
-          step, mag_ratios, accumulated_state, magcache_thresh, magcache_K, skip_warmup
+          step,
+          mag_ratios,
+          accumulated_state,
+          magcache_thresh,
+          magcache_K,
+          skip_warmup,
       )
 
     latents_input = latents
@@ -417,6 +459,9 @@ def scan_body(carry, t):
         skip_blocks=bool(skip_blocks) if use_magcache and do_cfg else None,
         cached_residual=cached_residual if use_magcache and do_cfg else None,
         return_residual=True if use_magcache and do_cfg else False,
+        kv_cache=kv_cache,
+        rotary_emb=rotary_emb,
+        encoder_attention_mask=encoder_attention_mask,
     )
     if use_magcache and do_cfg:
       noise_pred, _, residual_x_cur = outputs
diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py
index f466ec574..9944e2dc2 100644
--- a/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py
+++ b/src/maxdiffusion/pipelines/wan/wan_pipeline_i2v_2p2.py
@@ -95,7 +95,13 @@ def from_pretrained(cls, config: HyperParameters, vae_only=False, load_transform
     return pipeline
 
   @classmethod
-  def from_checkpoint(cls, config: HyperParameters, restored_checkpoint=None, vae_only=False, load_transformer=True):
+  def from_checkpoint(
+      cls,
+      config: HyperParameters,
+      restored_checkpoint=None,
+      vae_only=False,
+      load_transformer=True,
+  ):
     pipeline, _, _ = cls._load_and_init(config, restored_checkpoint, vae_only, load_transformer)
     return pipeline
 
@@ -126,7 +132,13 @@ def prepare_latents(
     latent_height = height // self.vae_scale_factor_spatial
     latent_width = width // self.vae_scale_factor_spatial
 
-    shape = (batch_size, num_latent_frames, latent_height, latent_width, num_channels_latents)
+    shape = (
+        batch_size,
+        num_latent_frames,
+        latent_height,
+        latent_width,
+        num_channels_latents,
+    )
 
     if latents is None:
       latents = jax.random.normal(rng, shape=shape, dtype=jnp.float32)
@@ -144,7 +156,12 @@ def prepare_latents(
       first_frame_mask = jnp.repeat(first_frame_mask, self.vae_scale_factor_temporal, axis=2)
       mask_lat_size = jnp.concatenate([first_frame_mask, mask_lat_size[:, :, 1:]], axis=2)
       mask_lat_size = mask_lat_size.reshape(
-          batch_size, 1, num_latent_frames, self.vae_scale_factor_temporal, latent_height, latent_width
+          batch_size,
+          1,
+          num_latent_frames,
+          self.vae_scale_factor_temporal,
+          latent_height,
+          latent_width,
       )
       mask_lat_size = jnp.transpose(mask_lat_size, (0, 2, 4, 5, 3, 1)).squeeze(-1)
       condition = jnp.concatenate([mask_lat_size, latent_condition], axis=-1)
@@ -172,6 +189,7 @@ def __call__(
       rng: Optional[jax.Array] = None,
       use_cfg_cache: bool = False,
       use_sen_cache: bool = False,
+      use_kv_cache: bool = False,
   ):
     if use_cfg_cache and use_sen_cache:
       raise ValueError("use_cfg_cache and use_sen_cache are mutually exclusive. Enable only one.")
@@ -202,7 +220,6 @@ def __call__(
       num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
       max_logging.log(f"Adjusted num_frames to: {num_frames}")
     num_frames = max(num_frames, 1)
-
     trace = {}
 
     t_cond_start = time.perf_counter()
@@ -256,7 +273,9 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt):
     trace["conditioning"] = time.perf_counter() - t_cond_start
 
     scheduler_state = self.scheduler.set_timesteps(
-        self.scheduler_state, num_inference_steps=num_inference_steps, shape=latents.shape
+        self.scheduler_state,
+        num_inference_steps=num_inference_steps,
+        shape=latents.shape,
     )
 
     low_noise_graphdef, low_noise_state, low_noise_rest = nnx.split(self.low_noise_transformer, nnx.Param, ...)
@@ -288,6 +307,7 @@ def _process_image_input(img_input, height, width, num_videos_per_prompt):
         use_sen_cache=use_sen_cache,
         height=height,
         config=self.config,
+        use_kv_cache=use_kv_cache,
     )
 
     t_denoise_start = time.perf_counter()
@@ -344,10 +364,42 @@ def run_inference_2_2_i2v(
     use_sen_cache: bool = False,
     height: int = 480,
     config=None,
+    use_kv_cache: bool = False,
 ):
   do_classifier_free_guidance = guidance_scale_low > 1.0 or guidance_scale_high > 1.0
   bsz = latents.shape[0]
 
+  prompt_embeds_combined = (
+      jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0) if do_classifier_free_guidance else prompt_embeds
+  )
+  if image_embeds is not None:
+    image_embeds_combined = (
+        jnp.concatenate([image_embeds, image_embeds], axis=0) if do_classifier_free_guidance else image_embeds
+    )
+  else:
+    image_embeds_combined = None
+
+  low_transformer = nnx.merge(low_noise_graphdef, low_noise_state, low_noise_rest)
+
+  # Compute RoPE once as it only depends on shape
+  dummy_hidden_states = jnp.zeros(latents.shape)
+  rotary_emb = low_transformer.rope(dummy_hidden_states)
+
+  kv_cache_low = None
+  encoder_attention_mask_low = None
+  kv_cache_high = None
+  encoder_attention_mask_high = None
+
+  if use_kv_cache:
+    kv_cache_low, encoder_attention_mask_low = low_transformer.compute_kv_cache(
+        prompt_embeds_combined, image_embeds_combined
+    )
+
+    high_transformer = nnx.merge(high_noise_graphdef, high_noise_state, high_noise_rest)
+    kv_cache_high, encoder_attention_mask_high = high_transformer.compute_kv_cache(
+        prompt_embeds_combined, image_embeds_combined
+    )
+
   # ── SenCache path (arXiv:2602.24208) ──
   if use_sen_cache and do_classifier_free_guidance:
     timesteps_np = np.array(scheduler_state.timesteps, dtype=np.int32)
@@ -365,11 +417,6 @@ def run_inference_2_2_i2v(
     nocache_end_begin = int(num_inference_steps * (1.0 - nocache_end_ratio))
     num_train_timesteps = float(scheduler.config.num_train_timesteps)
 
-    prompt_embeds_combined = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
-    if image_embeds is not None:
-      image_embeds_combined = jnp.concatenate([image_embeds, image_embeds], axis=0)
-    else:
-      image_embeds_combined = None
     condition_doubled = jnp.concatenate([condition] * 2)
 
     # SenCache state
@@ -386,11 +433,23 @@ def run_inference_2_2_i2v(
       t_float = float(timesteps_np[step]) / num_train_timesteps
 
       if step_uses_high[step]:
-        graphdef, state, rest = high_noise_graphdef, high_noise_state, high_noise_rest
+        graphdef, state, rest = (
+            high_noise_graphdef,
+            high_noise_state,
+            high_noise_rest,
+        )
         guidance_scale = guidance_scale_high
+        kv_cache = kv_cache_high
+        encoder_attention_mask = encoder_attention_mask_high
       else:
-        graphdef, state, rest = low_noise_graphdef, low_noise_state, low_noise_rest
+        graphdef, state, rest = (
+            low_noise_graphdef,
+            low_noise_state,
+            low_noise_rest,
+        )
         guidance_scale = guidance_scale_low
+        kv_cache = kv_cache_low
+        encoder_attention_mask = encoder_attention_mask_low
 
       is_boundary = step > 0 and step_uses_high[step] != step_uses_high[step - 1]
       force_compute = (
@@ -411,6 +470,9 @@ def run_inference_2_2_i2v(
             prompt_embeds_combined,
             guidance_scale=guidance_scale,
             encoder_hidden_states_image=image_embeds_combined,
+            kv_cache=kv_cache,
+            rotary_emb=rotary_emb,
+            encoder_attention_mask=encoder_attention_mask,
         )
         noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1))
         ref_noise_pred = noise_pred
@@ -447,6 +509,9 @@ def run_inference_2_2_i2v(
             prompt_embeds_combined,
             guidance_scale=guidance_scale,
             encoder_hidden_states_image=image_embeds_combined,
+            kv_cache=kv_cache,
+            rotary_emb=rotary_emb,
+            encoder_attention_mask=encoder_attention_mask,
         )
         noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1))
         ref_noise_pred = noise_pred
@@ -483,14 +548,11 @@ def run_inference_2_2_i2v(
 
     # Pre-split embeds
     prompt_cond_embeds = prompt_embeds
-    prompt_embeds_combined = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
 
     if image_embeds is not None:
       image_embeds_cond = image_embeds
-      image_embeds_combined = jnp.concatenate([image_embeds, image_embeds], axis=0)
     else:
       image_embeds_cond = None
-      image_embeds_combined = None
 
     # Keep condition in both single and doubled forms
     condition_cond = condition
@@ -534,11 +596,23 @@ def run_inference_2_2_i2v(
      is_cache_step = step_is_cache[step]
 
      if step_uses_high[step]:
-        graphdef, state, rest = high_noise_graphdef, high_noise_state, high_noise_rest
+        graphdef, state, rest = (
+            high_noise_graphdef,
+            high_noise_state,
+            high_noise_rest,
+        )
        guidance_scale = guidance_scale_high
+        kv_cache = kv_cache_high
+        encoder_attention_mask = encoder_attention_mask_high
      else:
-        graphdef, state, rest = low_noise_graphdef, low_noise_state, low_noise_rest
+        graphdef, state, rest = (
+            low_noise_graphdef,
+            low_noise_state,
+            low_noise_rest,
+        )
        guidance_scale = guidance_scale_low
+        kv_cache = kv_cache_low
+        encoder_attention_mask = encoder_attention_mask_low
 
      if is_cache_step:
        # ── Cache step: cond-only forward + FFT frequency compensation ──
@@ -547,6 +621,8 @@ def run_inference_2_2_i2v(
        latent_model_input = jnp.concatenate([latents, condition_cond], axis=-1)
        latent_model_input = jnp.transpose(latent_model_input, (0, 4, 1, 2, 3))
        timestep = jnp.broadcast_to(t, bsz)
+        kv_cache_cond = jax.tree_map(lambda x: x[:, :bsz], kv_cache) if kv_cache is not None else None
+        encoder_attention_mask_cond = encoder_attention_mask[:bsz] if encoder_attention_mask is not None else None
        noise_pred, cached_noise_cond = transformer_forward_pass_cfg_cache(
            graphdef,
            state,
@@ -560,6 +636,9 @@ def run_inference_2_2_i2v(
            w1=jnp.float32(w1),
            w2=jnp.float32(w2),
            encoder_hidden_states_image=image_embeds_cond,
+            kv_cache=kv_cache_cond,
+            rotary_emb=rotary_emb,
+            encoder_attention_mask=encoder_attention_mask_cond,
        )
      else:
        # ── Full CFG step: doubled batch, store raw cond/uncond for cache ──
@@ -567,7 +646,11 @@ def run_inference_2_2_i2v(
        latent_model_input = jnp.concatenate([latents_doubled, condition_doubled], axis=-1)
        latent_model_input = jnp.transpose(latent_model_input, (0, 4, 1, 2, 3))
        timestep = jnp.broadcast_to(t, bsz * 2)
-        noise_pred, cached_noise_cond, cached_noise_uncond = transformer_forward_pass_full_cfg(
+        (
+            noise_pred,
+            cached_noise_cond,
+            cached_noise_uncond,
+        ) = transformer_forward_pass_full_cfg(
            graphdef,
            state,
            rest,
@@ -576,6 +659,9 @@ def run_inference_2_2_i2v(
            prompt_embeds_combined,
            guidance_scale=guidance_scale,
            encoder_hidden_states_image=image_embeds_combined,
+            kv_cache=kv_cache,
+            rotary_emb=rotary_emb,
+            encoder_attention_mask=encoder_attention_mask,
        )
        noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1))  # BCFHW -> BFHWC
 
@@ -584,7 +670,17 @@ def run_inference_2_2_i2v(
 
   # ── Original non-cache path ──
   def high_noise_branch(operands):
-    latents_input, ts_input, pe_input, ie_input = operands
+    (
+        latents_input,
+        ts_input,
+        pe_input,
+        ie_input,
+        kv_cache_high,
+        _,
+        r_emb,
+        mask_high,
+        _,
+    ) = operands
     latents_input = jnp.transpose(latents_input, (0, 4, 1, 2, 3))
     noise_pred, latents_out = transformer_forward_pass(
         high_noise_graphdef,
@@ -596,11 +692,24 @@ def high_noise_branch(operands):
        do_classifier_free_guidance=do_classifier_free_guidance,
        guidance_scale=guidance_scale_high,
        encoder_hidden_states_image=ie_input,
+        kv_cache=kv_cache_high,
+        rotary_emb=r_emb,
+        encoder_attention_mask=mask_high,
    )
    return noise_pred, latents_out
 
  def low_noise_branch(operands):
-    latents_input, ts_input, pe_input, ie_input = operands
+    (
+        latents_input,
+        ts_input,
+        pe_input,
+        ie_input,
+        _,
+        kv_cache_low,
+        r_emb,
+        _,
+        mask_low,
+    ) = operands
    latents_input = jnp.transpose(latents_input, (0, 4, 1, 2, 3))
    noise_pred, latents_out = transformer_forward_pass(
        low_noise_graphdef,
@@ -612,19 +721,22 @@ def low_noise_branch(operands):
        do_classifier_free_guidance=do_classifier_free_guidance,
        guidance_scale=guidance_scale_low,
        encoder_hidden_states_image=ie_input,
+        kv_cache=kv_cache_low,
+        rotary_emb=r_emb,
+        encoder_attention_mask=mask_low,
    )
    return noise_pred, latents_out
 
  if do_classifier_free_guidance:
-    prompt_embeds = jnp.concatenate([prompt_embeds, negative_prompt_embeds], axis=0)
-    # WAN 2.2 I2V: image_embeds may be None since it doesn't use CLIP image encoder
-    if image_embeds is not None:
-      image_embeds = jnp.concatenate([image_embeds, image_embeds], axis=0)
    condition = jnp.concatenate([condition] * 2)
 
  first_profiling_step = config.skip_first_n_steps_for_profiler if config else 0
  profiler_steps = config.profiler_steps if config else 0
-  last_profiling_step = np.clip(first_profiling_step + profiler_steps - 1, first_profiling_step, num_inference_steps - 1)
+  last_profiling_step = np.clip(
+      first_profiling_step + profiler_steps - 1,
+      first_profiling_step,
+      num_inference_steps - 1,
+  )
 
  scan_diffusion_loop = getattr(config, "scan_diffusion_loop", False) if config else False
 
@@ -675,7 +787,20 @@ def scan_body(carry, t):
    use_high_noise = jnp.greater_equal(t, boundary)
 
    noise_pred, _ = jax.lax.cond(
-        use_high_noise, high_noise_branch, low_noise_branch, (latent_model_input, timestep, prompt_embeds, image_embeds)
+        use_high_noise,
+        high_noise_branch,
+        low_noise_branch,
+        (
+            latent_model_input,
+            timestep,
+            prompt_embeds_combined,
+            image_embeds_combined,
+            kv_cache_high,
+            kv_cache_low,
+            rotary_emb,
+            encoder_attention_mask_high,
+            encoder_attention_mask_low,
+        ),
    )
    noise_pred = jnp.transpose(noise_pred, (0, 2, 3, 4, 1))
    latents, scheduler_state = scheduler.step(scheduler_state, noise_pred, t, latents).to_tuple()
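A note on the jax.lax.cond wiring above: both branches must accept the same operand pytree, so the operand tuple carries the high-noise and low-noise caches and masks together, and each branch unpacks only its own entries (the underscores discard the other branch's). A minimal, self-contained sketch of that pattern, with toy values rather than the pipeline's actual tensors:

import jax
import jax.numpy as jnp

def high_branch(operands):
  x, kv_high, _ = operands  # ignore the low-noise cache
  return x + kv_high

def low_branch(operands):
  x, _, kv_low = operands   # ignore the high-noise cache
  return x + kv_low

x = jnp.ones((2,))
kv_high = jnp.full((2,), 10.0)
kv_low = jnp.full((2,), 20.0)
out = jax.lax.cond(jnp.array(True), high_branch, low_branch, (x, kv_high, kv_low))
# out == [11., 11.]; both branches are traced, and the predicate selects one at run time.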