diff --git a/QEfficient/diffusers/first_block_cache/wan.py b/QEfficient/diffusers/first_block_cache/wan.py index 5cfdd842b7..7ae7551d20 100644 --- a/QEfficient/diffusers/first_block_cache/wan.py +++ b/QEfficient/diffusers/first_block_cache/wan.py @@ -24,6 +24,7 @@ from diffusers.models.modeling_outputs import Transformer2DModelOutput from QEfficient.utils import constants +from QEfficient.utils.logging_utils import logger def _check_similarity(first_block_residuals: torch.Tensor, prev_first_block_residuals: torch.Tensor) -> torch.Tensor: @@ -228,6 +229,7 @@ def run_wan_non_unified_first_block_cache_denoise( callback_on_step_end_tensor_inputs: List[str], cache_threshold_high: Optional[float] = None, cache_threshold_low: Optional[float] = None, + magcache_runtime: Optional[Any] = None, ): """ Cache-aware non-unified WAN denoise loop. @@ -308,12 +310,19 @@ def run_wan_non_unified_first_block_cache_denoise( "cache_threshold": np.array(stage_cache_threshold, dtype=np.float32), } - with current_model.cache_context("cond"): + def _run_first_block_cache_step(stream_name: str, inputs: Dict[str, np.ndarray]) -> torch.Tensor: + if magcache_runtime is not None and magcache_runtime.should_skip(stream_name): + cached_residual = magcache_runtime.get_cached_residual(stream_name) + magcache_runtime.complete_skip(stream_name) + if magcache_runtime.verbose: + logger.info(f"MagCache skip: step={i}, stream={stream_name}, t={float(t):.2f}") + return latents.to(cached_residual.dtype) + cached_residual + start_transformer_step_time = time.perf_counter() - outputs = current_transformer_module.qpc_session.run(inputs_aic) + outputs = current_transformer_module.qpc_session.run(inputs) end_transformer_step_time = time.perf_counter() transformer_perf.append(end_transformer_step_time - start_transformer_step_time) - noise_pred = pipeline._reshape_noise_prediction( + noise_pred_step = pipeline._reshape_noise_prediction( outputs, batch_size, post_patch_num_frames, @@ -324,22 +333,19 @@ def run_wan_non_unified_first_block_cache_denoise( p_w, ) + if magcache_runtime is not None: + residual = noise_pred_step - latents.to(noise_pred_step.dtype) + magcache_runtime.complete_call(stream_name, residual) + if magcache_runtime.verbose: + logger.info(f"MagCache run: step={i}, stream={stream_name}, t={float(t):.2f}") + return noise_pred_step + + with current_model.cache_context("cond"): + noise_pred = _run_first_block_cache_step("cond", inputs_aic) + if pipeline.do_classifier_free_guidance: with current_model.cache_context("uncond"): - start_transformer_step_time = time.perf_counter() - outputs = current_transformer_module.qpc_session.run(inputs_aic2) - end_transformer_step_time = time.perf_counter() - transformer_perf.append(end_transformer_step_time - start_transformer_step_time) - noise_uncond = pipeline._reshape_noise_prediction( - outputs, - batch_size, - post_patch_num_frames, - post_patch_height, - post_patch_width, - p_t, - p_h, - p_w, - ) + noise_uncond = _run_first_block_cache_step("uncond", inputs_aic2) noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond) latents = pipeline.scheduler.step(noise_pred, t, latents, return_dict=False)[0] diff --git a/QEfficient/diffusers/pipelines/wan/magcache.py b/QEfficient/diffusers/pipelines/wan/magcache.py new file mode 100644 index 0000000000..e0995930ac --- /dev/null +++ b/QEfficient/diffusers/pipelines/wan/magcache.py @@ -0,0 +1,286 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ---------------------------------------------------------------------------- +"""Runtime MagCache helpers for WAN pipelines. + +This module implements a pipeline-level (graph-agnostic) MagCache controller. +It does not modify ONNX/QPC graph signatures. +""" + +from __future__ import annotations + +from dataclasses import dataclass, field +from typing import Dict, Optional, Sequence + +import numpy as np +import torch + +# Wan2.2 T2V-A14B mag ratios from MagCache4Wan2.2. +DEFAULT_WAN2_2_T2V_A14B_MAG_RATIOS = [ + 1.00124, + 1.00155, + 0.99822, + 0.99851, + 0.99696, + 0.99687, + 0.99703, + 0.99732, + 0.9966, + 0.99679, + 0.99602, + 0.99658, + 0.99578, + 0.99664, + 0.99484, + 0.9949, + 0.99633, + 0.996, + 0.99659, + 0.99683, + 0.99534, + 0.99549, + 0.99584, + 0.99577, + 0.99681, + 0.99694, + 0.99563, + 0.99554, + 0.9944, + 0.99473, + 0.99594, + 0.9964, + 0.99466, + 0.99461, + 0.99453, + 0.99481, + 0.99389, + 0.99365, + 0.99391, + 0.99406, + 0.99354, + 0.99361, + 0.99283, + 0.99278, + 0.99268, + 0.99263, + 0.99057, + 0.99091, + 0.99125, + 0.99126, + 0.65523, + 0.65252, + 0.98808, + 0.98852, + 0.98765, + 0.98736, + 0.9851, + 0.98535, + 0.98311, + 0.98339, + 0.9805, + 0.9806, + 0.97776, + 0.97771, + 0.97278, + 0.97286, + 0.96731, + 0.96728, + 0.95857, + 0.95855, + 0.94385, + 0.94385, + 0.92118, + 0.921, + 0.88108, + 0.88076, + 0.80263, + 0.80181, +] + + +def nearest_interp(src_array: np.ndarray, target_length: int) -> np.ndarray: + """Nearest-neighbor interpolation used by the upstream MagCache scripts.""" + src_length = len(src_array) + if target_length == 1: + return np.array([src_array[-1]], dtype=np.float32) + + scale = (src_length - 1) / (target_length - 1) + mapped_indices = np.round(np.arange(target_length) * scale).astype(int) + return src_array[mapped_indices].astype(np.float32) + + +@dataclass +class _StreamState: + cached_residual: Optional[torch.Tensor] = None + accumulated_ratio: float = 1.0 + accumulated_err: float = 0.0 + accumulated_steps: int = 0 + + def reset_accumulators(self) -> None: + self.accumulated_ratio = 1.0 + self.accumulated_err = 0.0 + self.accumulated_steps = 0 + + def reset_all(self) -> None: + self.cached_residual = None + self.reset_accumulators() + + +@dataclass +class WanMagCacheRuntime: + """Runtime state machine for WAN MagCache. + + This class tracks per-stream state (cond/uncond), applies stage-aware retention + windows, and decides whether to skip a QAIC forward call. + """ + + num_inference_steps: int + do_classifier_free_guidance: bool + threshold: float + max_skip_steps: int + retention_ratio: float + split_step: Optional[int] = None + ratios: Optional[Sequence[float]] = None + verbose: bool = False + + call_index: int = 0 + skipped_calls: int = 0 + executed_calls: int = 0 + stream_states: Dict[str, _StreamState] = field(default_factory=dict) + + def _debug_print(self, message: str) -> None: + if self.verbose: + print(message) + + def __post_init__(self) -> None: + if self.threshold < 0: + raise ValueError(f"`magcache_thresh` must be >= 0, got {self.threshold}.") + if self.max_skip_steps < 0: + raise ValueError(f"`magcache_K` must be >= 0, got {self.max_skip_steps}.") + if not 0.0 <= self.retention_ratio <= 1.0: + raise ValueError(f"`magcache_retention_ratio` must be in [0, 1], got {self.retention_ratio}.") + + self.calls_per_step = 2 if self.do_classifier_free_guidance else 1 + self.total_calls = self.num_inference_steps * self.calls_per_step + self._prepared_ratios = self._prepare_ratios( + self.ratios, + num_steps=self.num_inference_steps, + calls_per_step=self.calls_per_step, + ) + + self.stream_states = {"cond": _StreamState()} + if self.do_classifier_free_guidance: + self.stream_states["uncond"] = _StreamState() + + if self.split_step is not None: + # Convert timestep split to invocation split (cond/uncond aware). + self.split_step = int(self.split_step) * self.calls_per_step + + @staticmethod + def _prepare_ratios( + ratios: Optional[Sequence[float]], + num_steps: int, + calls_per_step: int, + ) -> np.ndarray: + raw = np.asarray( + DEFAULT_WAN2_2_T2V_A14B_MAG_RATIOS if ratios is None else list(ratios), + dtype=np.float32, + ) + + if calls_per_step == 1: + # If user provides interleaved cond/uncond ratios, use cond stream. + if raw.size % 2 == 0 and raw.size > 0: + raw = raw[0::2] + prepared = np.concatenate([np.array([1.0], dtype=np.float32), raw]) + if len(prepared) != num_steps: + prepared = nearest_interp(prepared, num_steps) + return prepared + + prepared = np.concatenate([np.array([1.0, 1.0], dtype=np.float32), raw]) + if len(prepared) != num_steps * 2: + mag_ratio_cond = nearest_interp(prepared[0::2], num_steps) + mag_ratio_uncond = nearest_interp(prepared[1::2], num_steps) + prepared = np.empty(num_steps * 2, dtype=np.float32) + prepared[0::2] = mag_ratio_cond + prepared[1::2] = mag_ratio_uncond + return prepared + + def _cache_allowed_for_call(self, call_index: int) -> bool: + # Single-stage mode (e.g., no high/low split): warmup-only retention window. + if self.split_step is None: + return call_index >= int(self.total_calls * self.retention_ratio) + + # Wan2.2 T2V/I2V-like stage-aware retention scheduling. + retain_high = int(self.split_step * self.retention_ratio) + retain_low_end = int((self.total_calls - self.split_step) * self.retention_ratio + self.split_step) + + if call_index < retain_high: + return False + if self.split_step <= call_index <= retain_low_end: + return False + return True + + def should_skip(self, stream_name: str) -> bool: + state = self.stream_states[stream_name] + + if not self._cache_allowed_for_call(self.call_index): + self._debug_print( + f"[MagCache] call={self.call_index} stream={stream_name} diff=N/A " + f"thresh={self.threshold:.6f} decision=run (retention window)" + ) + return False + if state.cached_residual is None: + self._debug_print( + f"[MagCache] call={self.call_index} stream={stream_name} diff=N/A " + f"thresh={self.threshold:.6f} decision=run (cache cold start)" + ) + return False + + ratio = float(self._prepared_ratios[self.call_index]) + state.accumulated_ratio *= ratio + state.accumulated_steps += 1 + state.accumulated_err += abs(1.0 - state.accumulated_ratio) + + should_skip = state.accumulated_err < self.threshold and state.accumulated_steps <= self.max_skip_steps + self._debug_print( + f"[MagCache] call={self.call_index} stream={stream_name} diff={state.accumulated_err:.6f} " + f"thresh={self.threshold:.6f} k={state.accumulated_steps}/{self.max_skip_steps} " + f"decision={'skip' if should_skip else 'run'}" + ) + if should_skip: + self.skipped_calls += 1 + self._debug_print(f"[MagCache] stream={stream_name} diff<{self.threshold:.6f}; skipping this step for now.") + return True + + state.reset_accumulators() + return False + + def get_cached_residual(self, stream_name: str) -> torch.Tensor: + cached = self.stream_states[stream_name].cached_residual + if cached is None: + raise RuntimeError(f"MagCache residual is empty for stream '{stream_name}'.") + return cached + + def complete_call(self, stream_name: str, residual: torch.Tensor) -> None: + state = self.stream_states[stream_name] + state.cached_residual = residual.detach() + self.executed_calls += 1 + + self.call_index += 1 + if self.call_index >= self.total_calls: + self._reset_for_next_video() + + def complete_skip(self, stream_name: str) -> None: + if stream_name not in self.stream_states: + raise KeyError(f"Unknown stream name '{stream_name}'.") + self.call_index += 1 + if self.call_index >= self.total_calls: + self._reset_for_next_video() + + def _reset_for_next_video(self) -> None: + self.call_index = 0 + for state in self.stream_states.values(): + state.reset_all() diff --git a/QEfficient/diffusers/pipelines/wan/pipeline_wan.py b/QEfficient/diffusers/pipelines/wan/pipeline_wan.py index cd80269ed2..ecb2fa21d9 100644 --- a/QEfficient/diffusers/pipelines/wan/pipeline_wan.py +++ b/QEfficient/diffusers/pipelines/wan/pipeline_wan.py @@ -43,6 +43,7 @@ config_manager, set_execute_params, ) +from QEfficient.diffusers.pipelines.wan.magcache import WanMagCacheRuntime from QEfficient.generation.cloud_infer import QAICInferenceSession from QEfficient.utils import constants from QEfficient.utils.logging_utils import logger @@ -517,6 +518,7 @@ def _run_denoise_loop_unified( callback_on_step_end_tensor_inputs: List[str], cache_threshold_high: Optional[float] = None, cache_threshold_low: Optional[float] = None, + magcache_runtime: Optional[WanMagCacheRuntime] = None, ): transformer_perf = [] with self.model.progress_bar(total=num_inference_steps) as progress_bar: @@ -584,12 +586,19 @@ def _run_denoise_loop_unified( "tsp": model_type.detach().numpy(), } - with current_model.cache_context("cond"): + def _run_unified_step(stream_name: str, inputs: Dict[str, np.ndarray]) -> torch.Tensor: + if magcache_runtime is not None and magcache_runtime.should_skip(stream_name): + cached_residual = magcache_runtime.get_cached_residual(stream_name) + magcache_runtime.complete_skip(stream_name) + if magcache_runtime.verbose: + logger.info(f"MagCache skip: step={i}, stream={stream_name}, t={float(t):.2f}") + return latents.to(cached_residual.dtype) + cached_residual + start_transformer_step_time = time.perf_counter() - outputs = self.transformer.qpc_session.run(inputs_aic) + outputs = self.transformer.qpc_session.run(inputs) end_transformer_step_time = time.perf_counter() transformer_perf.append(end_transformer_step_time - start_transformer_step_time) - noise_pred = self._reshape_noise_prediction( + noise_pred_step = self._reshape_noise_prediction( outputs, batch_size, post_patch_num_frames, @@ -600,22 +609,20 @@ def _run_denoise_loop_unified( p_w, ) + if magcache_runtime is not None: + residual = noise_pred_step - latents.to(noise_pred_step.dtype) + magcache_runtime.complete_call(stream_name, residual) + if magcache_runtime.verbose: + logger.info(f"MagCache run: step={i}, stream={stream_name}, t={float(t):.2f}") + + return noise_pred_step + + with current_model.cache_context("cond"): + noise_pred = _run_unified_step("cond", inputs_aic) + if self.do_classifier_free_guidance: with current_model.cache_context("uncond"): - start_transformer_step_time = time.perf_counter() - outputs = self.transformer.qpc_session.run(inputs_aic2) - end_transformer_step_time = time.perf_counter() - transformer_perf.append(end_transformer_step_time - start_transformer_step_time) - noise_uncond = self._reshape_noise_prediction( - outputs, - batch_size, - post_patch_num_frames, - post_patch_height, - post_patch_width, - p_t, - p_h, - p_w, - ) + noise_uncond = _run_unified_step("uncond", inputs_aic2) noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond) latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] @@ -650,6 +657,7 @@ def _run_denoise_loop_non_unified( callback_on_step_end_tensor_inputs: List[str], cache_threshold_high: Optional[float] = None, cache_threshold_low: Optional[float] = None, + magcache_runtime: Optional[WanMagCacheRuntime] = None, ): transformer_perf = [] with self.model.progress_bar(total=num_inference_steps) as progress_bar: @@ -714,12 +722,19 @@ def _run_denoise_loop_non_unified( "timestep_proj": timestep_proj.detach().numpy(), } - with current_model.cache_context("cond"): + def _run_non_unified_step(stream_name: str, inputs: Dict[str, np.ndarray]) -> torch.Tensor: + if magcache_runtime is not None and magcache_runtime.should_skip(stream_name): + cached_residual = magcache_runtime.get_cached_residual(stream_name) + magcache_runtime.complete_skip(stream_name) + if magcache_runtime.verbose: + logger.info(f"MagCache skip: step={i}, stream={stream_name}, t={float(t):.2f}") + return latents.to(cached_residual.dtype) + cached_residual + start_transformer_step_time = time.perf_counter() - outputs = current_transformer_module.qpc_session.run(inputs_aic) + outputs = current_transformer_module.qpc_session.run(inputs) end_transformer_step_time = time.perf_counter() transformer_perf.append(end_transformer_step_time - start_transformer_step_time) - noise_pred = self._reshape_noise_prediction( + noise_pred_step = self._reshape_noise_prediction( outputs, batch_size, post_patch_num_frames, @@ -730,22 +745,20 @@ def _run_denoise_loop_non_unified( p_w, ) + if magcache_runtime is not None: + residual = noise_pred_step - latents.to(noise_pred_step.dtype) + magcache_runtime.complete_call(stream_name, residual) + if magcache_runtime.verbose: + logger.info(f"MagCache run: step={i}, stream={stream_name}, t={float(t):.2f}") + + return noise_pred_step + + with current_model.cache_context("cond"): + noise_pred = _run_non_unified_step("cond", inputs_aic) + if self.do_classifier_free_guidance: with current_model.cache_context("uncond"): - start_transformer_step_time = time.perf_counter() - outputs = current_transformer_module.qpc_session.run(inputs_aic2) - end_transformer_step_time = time.perf_counter() - transformer_perf.append(end_transformer_step_time - start_transformer_step_time) - noise_uncond = self._reshape_noise_prediction( - outputs, - batch_size, - post_patch_num_frames, - post_patch_height, - post_patch_width, - p_t, - p_h, - p_w, - ) + noise_uncond = _run_non_unified_step("uncond", inputs_aic2) noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond) latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0] @@ -788,6 +801,12 @@ def __call__( cache_threshold_high: Optional[float] = None, cache_threshold_low: Optional[float] = None, parallel_compile: bool = True, + use_magcache: bool = False, + magcache_thresh: float = 0.06, + magcache_K: int = 2, + magcache_retention_ratio: float = 0.4, + magcache_ratios: Optional[List[float]] = None, + magcache_verbose: bool = False, ): """ Generate videos from text prompts using the QEfficient-optimized WAN pipeline on QAIC hardware. @@ -834,6 +853,12 @@ def __call__( cache_threshold_low (float, optional): First-block-cache threshold for low-noise stage. Used only when `enable_first_block_cache=True`. parallel_compile (bool, optional): Whether to compile modules in parallel. Default: True. + use_magcache (bool, optional): Enable WAN runtime MagCache skip/reuse logic. Default: False. + magcache_thresh (float, optional): MagCache accumulated error threshold. Default: 0.06. + magcache_K (int, optional): Maximum number of consecutive skipped calls per stream. Default: 2. + magcache_retention_ratio (float, optional): Stage retention ratio in [0, 1]. Default: 0.4. + magcache_ratios (List[float], optional): Optional custom MagCache ratio profile. + magcache_verbose (bool, optional): Emit per-call MagCache decisions to logger. Default: False. Returns: QEffPipelineOutput: A dataclass containing: @@ -897,11 +922,34 @@ def __call__( if self.model.config.boundary_ratio is not None and guidance_scale_2 is None: guidance_scale_2 = guidance_scale + if isinstance(use_magcache, str): + lowered = use_magcache.strip().lower() + if lowered in {"1", "true", "yes", "on"}: + use_magcache = True + elif lowered in {"0", "false", "no", "off"}: + use_magcache = False + else: + raise ValueError( + f"Invalid string value for `use_magcache`: {use_magcache!r}. " + "Use one of: true/false, 1/0, yes/no, on/off." + ) + elif not isinstance(use_magcache, bool): + use_magcache = bool(use_magcache) + logger.warning(f"Coerced non-bool `use_magcache` to {use_magcache}.") + if not self.enable_first_block_cache and (cache_threshold_high is not None or cache_threshold_low is not None): logger.warning( "Ignoring cache thresholds because first-block-cache is disabled. " "Set `enable_first_block_cache=True` and `use_unified=False` to enable it." ) + if not use_magcache and ( + magcache_verbose + or magcache_ratios is not None + or magcache_thresh != 0.06 + or magcache_K != 2 + or magcache_retention_ratio != 0.4 + ): + logger.warning("Ignoring MagCache knobs because `use_magcache=False`.") # Initialize pipeline state self._guidance_scale = guidance_scale @@ -968,6 +1016,22 @@ def __call__( else: boundary_timestep = None + magcache_runtime = None + if use_magcache: + high_noise_steps = None + if boundary_timestep is not None: + high_noise_steps = int((timesteps >= boundary_timestep).sum().item()) + magcache_runtime = WanMagCacheRuntime( + num_inference_steps=num_inference_steps, + do_classifier_free_guidance=self.do_classifier_free_guidance, + threshold=magcache_thresh, + max_skip_steps=magcache_K, + retention_ratio=magcache_retention_ratio, + split_step=high_noise_steps, + ratios=magcache_ratios, + verbose=magcache_verbose, + ) + # Step 7: Initialize transformer sessions and buffers cl, _, _, _ = calculate_latent_dimensions_with_frames( height, @@ -996,6 +1060,7 @@ def __call__( callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs, cache_threshold_high=cache_threshold_high, cache_threshold_low=cache_threshold_low, + magcache_runtime=magcache_runtime, ) self._current_timestep = None diff --git a/examples/diffusers/wan/README.md b/examples/diffusers/wan/README.md index 6df291d802..1ab081bd66 100644 --- a/examples/diffusers/wan/README.md +++ b/examples/diffusers/wan/README.md @@ -10,6 +10,7 @@ WAN 2.2 is a text-to-video diffusion model that uses dual-stage processing for h - **`wan_lightning.py`** - Complete example with Lightning LoRA for fast video generation - **`wan_first_block_cache.py`** - Non-unified WAN with patch-based first-block-cache enabled +- **`wan_magcache.py`** - Non-unified WAN with runtime MagCache skip/reuse acceleration - **`wan_config.json`** - Contains default compilation config for transformer, vae modules. - **`wan_non_unified_config.json`** - Non-unified module config (`transformer_high`, `transformer_low`, `vae_decoder`) @@ -86,6 +87,11 @@ Run the Lightning example: python wan_lightning.py ``` +Run the MagCache example: +```bash +python wan_magcache.py +``` + ## Advanced Customization diff --git a/examples/diffusers/wan/wan_magcache.py b/examples/diffusers/wan/wan_magcache.py new file mode 100644 index 0000000000..18da3d6c37 --- /dev/null +++ b/examples/diffusers/wan/wan_magcache.py @@ -0,0 +1,43 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- +import torch +from diffusers.utils import export_to_video + +from QEfficient import QEffWanPipeline + +# Non-unified WAN + MagCache runtime acceleration (no ONNX/QPC signature changes). +pipeline = QEffWanPipeline.from_pretrained( + "Wan-AI/Wan2.2-T2V-A14B-Diffusers", + use_unified=False, +) + +prompt = "Two anthropomorphic cats in comfy boxing gear and bright gloves fight intensely on a spotlighted stage." +negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走" + +output = pipeline( + prompt=prompt, + negative_prompt=negative_prompt, + num_frames=81, + guidance_scale=4.0, + guidance_scale_2=3.0, + num_inference_steps=40, + generator=torch.Generator().manual_seed(42), + custom_config_path="examples/diffusers/wan/wan_non_unified_config.json", + height=96, + width=160, + parallel_compile=True, + use_onnx_subfunctions=True, + use_magcache=True, + magcache_thresh=0.06, + magcache_K=2, + magcache_retention_ratio=0.4, + magcache_verbose=False, +) + +frames = output.images[0] +export_to_video(frames, "wan_magcache.mp4", fps=16) +print(output) diff --git a/tests/diffusers/test_wan_magcache.py b/tests/diffusers/test_wan_magcache.py new file mode 100644 index 0000000000..f24b4146e1 --- /dev/null +++ b/tests/diffusers/test_wan_magcache.py @@ -0,0 +1,96 @@ +# ----------------------------------------------------------------------------- +# +# Copyright (c) Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# +# ----------------------------------------------------------------------------- + +import numpy as np +import pytest +import torch + +from QEfficient.diffusers.pipelines.wan.magcache import WanMagCacheRuntime, nearest_interp + + +@pytest.mark.diffusers +def test_nearest_interp_target_length_one_uses_last_value(): + src = np.asarray([0.1, 0.2, 0.9], dtype=np.float32) + out = nearest_interp(src, 1) + assert out.shape == (1,) + assert np.isclose(out[0], src[-1]) + + +@pytest.mark.diffusers +def test_prepare_ratios_cfg_and_non_cfg_lengths(): + ratios = [0.99, 0.98, 0.97, 0.96] + + cfg_runtime = WanMagCacheRuntime( + num_inference_steps=5, + do_classifier_free_guidance=True, + threshold=0.1, + max_skip_steps=2, + retention_ratio=0.0, + split_step=3, + ratios=ratios, + ) + assert len(cfg_runtime._prepared_ratios) == 10 + + non_cfg_runtime = WanMagCacheRuntime( + num_inference_steps=5, + do_classifier_free_guidance=False, + threshold=0.1, + max_skip_steps=2, + retention_ratio=0.0, + split_step=3, + ratios=ratios, + ) + assert len(non_cfg_runtime._prepared_ratios) == 5 + + +@pytest.mark.diffusers +def test_stage_aware_retention_window_behavior(): + runtime = WanMagCacheRuntime( + num_inference_steps=5, + do_classifier_free_guidance=False, + threshold=0.1, + max_skip_steps=2, + retention_ratio=0.4, + split_step=3, + ratios=[1.0] * 5, + ) + + allowed = [runtime._cache_allowed_for_call(i) for i in range(5)] + assert allowed == [False, True, True, False, True] + + +@pytest.mark.diffusers +def test_skip_path_advances_call_index_and_respects_k_limit(): + runtime = WanMagCacheRuntime( + num_inference_steps=4, + do_classifier_free_guidance=False, + threshold=1.0, + max_skip_steps=2, + retention_ratio=0.0, + split_step=None, + ratios=[1.0] * 4, + ) + + assert runtime.should_skip("cond") is False + runtime.complete_call("cond", torch.zeros(1)) + assert runtime.call_index == 1 + + assert runtime.should_skip("cond") is True + runtime.complete_skip("cond") + assert runtime.call_index == 2 + + assert runtime.should_skip("cond") is True + runtime.complete_skip("cond") + assert runtime.call_index == 3 + + # Third consecutive skip exceeds K=2 and should force execution. + assert runtime.should_skip("cond") is False + runtime.complete_call("cond", torch.zeros(1)) + + # End of denoise sequence resets runtime state for next video. + assert runtime.call_index == 0 + assert runtime.stream_states["cond"].cached_residual is None