diff --git a/diffsynth/pipelines/wan_video.py b/diffsynth/pipelines/wan_video.py index bbc479e29..f347a8d24 100644 --- a/diffsynth/pipelines/wan_video.py +++ b/diffsynth/pipelines/wan_video.py @@ -66,7 +66,6 @@ def __init__(self, device=get_device_type(), torch_dtype=torch.bfloat16): WanVideoUnit_FunCameraControl(), WanVideoUnit_SpeedControl(), WanVideoUnit_VACE(), - WanVideoUnit_AnimateVideoSplit(), WanVideoUnit_AnimatePoseLatents(), WanVideoUnit_AnimateFacePixelValues(), WanVideoUnit_AnimateInpaint(), @@ -351,12 +350,15 @@ def process(self, pipe: WanVideoPipeline, height, width, num_frames): class WanVideoUnit_NoiseInitializer(PipelineUnit): def __init__(self): super().__init__( - input_params=("height", "width", "num_frames", "seed", "rand_device", "vace_reference_image"), + input_params=("height", "width", "num_frames", "seed", "rand_device", "vace_reference_image", "input_image", "animate_pose_video"), output_params=("noise",) ) - def process(self, pipe: WanVideoPipeline, height, width, num_frames, seed, rand_device, vace_reference_image): + def process(self, pipe: WanVideoPipeline, height, width, num_frames, seed, rand_device, vace_reference_image, input_image, animate_pose_video): length = (num_frames - 1) // 4 + 1 + # For wan-animate, input_image is a single reference frame; align time dimension. + if input_image is not None and animate_pose_video is not None: + length += 1 if vace_reference_image is not None: f = len(vace_reference_image) if isinstance(vace_reference_image, list) else 1 length += f @@ -371,12 +373,12 @@ def process(self, pipe: WanVideoPipeline, height, width, num_frames, seed, rand_ class WanVideoUnit_InputVideoEmbedder(PipelineUnit): def __init__(self): super().__init__( - input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride", "vace_reference_image"), + input_params=("input_video", "noise", "tiled", "tile_size", "tile_stride", "vace_reference_image", "input_image", "animate_pose_video"), output_params=("latents", "input_latents"), onload_model_names=("vae",) ) - def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride, vace_reference_image): + def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, tile_stride, vace_reference_image, input_image, animate_pose_video): if input_video is None: return {"latents": noise} pipe.load_models_to_device(self.onload_model_names) @@ -388,6 +390,11 @@ def process(self, pipe: WanVideoPipeline, input_video, noise, tiled, tile_size, vace_reference_image = pipe.preprocess_video(vace_reference_image) vace_reference_latents = pipe.vae.encode(vace_reference_image, device=pipe.device).to(dtype=pipe.torch_dtype, device=pipe.device) input_latents = torch.concat([vace_reference_latents, input_latents], dim=2) + # For wan-animate, prepend the single reference frame latent + if input_image is not None and animate_pose_video is not None: + input_image = pipe.preprocess_video([input_image]) + input_image_latents = pipe.vae.encode(input_image, device=pipe.device).to(dtype=pipe.torch_dtype, device=pipe.device) + input_latents = torch.concat([input_image_latents, input_latents], dim=2) if pipe.scheduler.training: return {"latents": noise, "input_latents": input_latents} else: @@ -903,27 +910,6 @@ def process(self, pipe: WanVideoPipeline, latents, motion_latents, drop_motion_f return {"latents": latents} -class WanVideoUnit_AnimateVideoSplit(PipelineUnit): - def __init__(self): - super().__init__( - input_params=("input_video", "animate_pose_video", "animate_face_video", "animate_inpaint_video", "animate_mask_video"), - output_params=("animate_pose_video", "animate_face_video", "animate_inpaint_video", "animate_mask_video") - ) - - def process(self, pipe: WanVideoPipeline, input_video, animate_pose_video, animate_face_video, animate_inpaint_video, animate_mask_video): - if input_video is None: - return {} - if animate_pose_video is not None: - animate_pose_video = animate_pose_video[:len(input_video) - 4] - if animate_face_video is not None: - animate_face_video = animate_face_video[:len(input_video) - 4] - if animate_inpaint_video is not None: - animate_inpaint_video = animate_inpaint_video[:len(input_video) - 4] - if animate_mask_video is not None: - animate_mask_video = animate_mask_video[:len(input_video) - 4] - return {"animate_pose_video": animate_pose_video, "animate_face_video": animate_face_video, "animate_inpaint_video": animate_inpaint_video, "animate_mask_video": animate_mask_video} - - class WanVideoUnit_AnimatePoseLatents(PipelineUnit): def __init__(self): super().__init__( diff --git a/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py b/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py index d435b688f..0dedeb7f3 100644 --- a/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py +++ b/examples/wanvideo/model_inference/Wan2.2-Animate-14B.py @@ -26,15 +26,15 @@ # Animate input_image = Image.open("data/examples/wan/animate/animate_input_image.png") -animate_pose_video = VideoData("data/examples/wan/animate/animate_pose_video.mp4").raw_data()[:81-4] -animate_face_video = VideoData("data/examples/wan/animate/animate_face_video.mp4").raw_data()[:81-4] +animate_pose_video = VideoData("data/examples/wan/animate/animate_pose_video.mp4").raw_data()[:77] +animate_face_video = VideoData("data/examples/wan/animate/animate_face_video.mp4").raw_data()[:77] video = pipe( prompt="视频中的人在做动作", seed=0, tiled=True, input_image=input_image, animate_pose_video=animate_pose_video, animate_face_video=animate_face_video, - num_frames=81, height=720, width=1280, + num_frames=77, height=720, width=1280, num_inference_steps=20, cfg_scale=1, ) save_video(video, "video_1_Wan2.2-Animate-14B.mp4", fps=15, quality=5) @@ -44,10 +44,10 @@ lora_state_dict = load_state_dict("models/Wan-AI/Wan2.2-Animate-14B/relighting_lora.ckpt", torch_dtype=torch.bfloat16, device="cuda")["state_dict"] pipe.load_lora(pipe.dit, state_dict=lora_state_dict) input_image = Image.open("data/examples/wan/animate/replace_input_image.png") -animate_pose_video = VideoData("data/examples/wan/animate/replace_pose_video.mp4").raw_data()[:81-4] -animate_face_video = VideoData("data/examples/wan/animate/replace_face_video.mp4").raw_data()[:81-4] -animate_inpaint_video = VideoData("data/examples/wan/animate/replace_inpaint_video.mp4").raw_data()[:81-4] -animate_mask_video = VideoData("data/examples/wan/animate/replace_mask_video.mp4").raw_data()[:81-4] +animate_pose_video = VideoData("data/examples/wan/animate/replace_pose_video.mp4").raw_data()[:77] +animate_face_video = VideoData("data/examples/wan/animate/replace_face_video.mp4").raw_data()[:77] +animate_inpaint_video = VideoData("data/examples/wan/animate/replace_inpaint_video.mp4").raw_data()[:77] +animate_mask_video = VideoData("data/examples/wan/animate/replace_mask_video.mp4").raw_data()[:77] video = pipe( prompt="视频中的人在做动作", seed=0, tiled=True, @@ -56,7 +56,8 @@ animate_face_video=animate_face_video, animate_inpaint_video=animate_inpaint_video, animate_mask_video=animate_mask_video, - num_frames=81, height=720, width=1280, + num_frames=77, height=720, width=1280, num_inference_steps=20, cfg_scale=1, ) save_video(video, "video_2_Wan2.2-Animate-14B.mp4", fps=15, quality=5) + diff --git a/examples/wanvideo/model_inference_low_vram/Wan2.2-Animate-14B.py b/examples/wanvideo/model_inference_low_vram/Wan2.2-Animate-14B.py index 180482c14..fb6176ba2 100644 --- a/examples/wanvideo/model_inference_low_vram/Wan2.2-Animate-14B.py +++ b/examples/wanvideo/model_inference_low_vram/Wan2.2-Animate-14B.py @@ -37,15 +37,15 @@ # Animate input_image = Image.open("data/examples/wan/animate/animate_input_image.png") -animate_pose_video = VideoData("data/examples/wan/animate/animate_pose_video.mp4").raw_data()[:81-4] -animate_face_video = VideoData("data/examples/wan/animate/animate_face_video.mp4").raw_data()[:81-4] +animate_pose_video = VideoData("data/examples/wan/animate/animate_pose_video.mp4").raw_data()[:77] +animate_face_video = VideoData("data/examples/wan/animate/animate_face_video.mp4").raw_data()[:77] video = pipe( prompt="视频中的人在做动作", seed=0, tiled=True, input_image=input_image, animate_pose_video=animate_pose_video, animate_face_video=animate_face_video, - num_frames=81, height=720, width=1280, + num_frames=77, height=720, width=1280, num_inference_steps=20, cfg_scale=1, ) save_video(video, "video_1_Wan2.2-Animate-14B.mp4", fps=15, quality=5) @@ -56,10 +56,10 @@ lora_state_dict = {i: lora_state_dict[i].to(torch.bfloat16) for i in lora_state_dict} pipe.load_lora(pipe.dit, state_dict=lora_state_dict) input_image = Image.open("data/examples/wan/animate/replace_input_image.png") -animate_pose_video = VideoData("data/examples/wan/animate/replace_pose_video.mp4").raw_data()[:81-4] -animate_face_video = VideoData("data/examples/wan/animate/replace_face_video.mp4").raw_data()[:81-4] -animate_inpaint_video = VideoData("data/examples/wan/animate/replace_inpaint_video.mp4").raw_data()[:81-4] -animate_mask_video = VideoData("data/examples/wan/animate/replace_mask_video.mp4").raw_data()[:81-4] +animate_pose_video = VideoData("data/examples/wan/animate/replace_pose_video.mp4").raw_data()[:77] +animate_face_video = VideoData("data/examples/wan/animate/replace_face_video.mp4").raw_data()[:77] +animate_inpaint_video = VideoData("data/examples/wan/animate/replace_inpaint_video.mp4").raw_data()[:77] +animate_mask_video = VideoData("data/examples/wan/animate/replace_mask_video.mp4").raw_data()[:77] video = pipe( prompt="视频中的人在做动作", seed=0, tiled=True, @@ -68,7 +68,8 @@ animate_face_video=animate_face_video, animate_inpaint_video=animate_inpaint_video, animate_mask_video=animate_mask_video, - num_frames=81, height=720, width=1280, + num_frames=77, height=720, width=1280, num_inference_steps=20, cfg_scale=1, ) save_video(video, "video_2_Wan2.2-Animate-14B.mp4", fps=15, quality=5) + diff --git a/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py b/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py index 0cdce0656..2d965d45f 100644 --- a/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py +++ b/examples/wanvideo/model_training/validate_full/Wan2.2-Animate-14B.py @@ -19,15 +19,16 @@ pipe.animate_adapter.load_state_dict(state_dict, strict=False) input_image = VideoData("data/example_video_dataset/animate/animate_output.mp4", height=480, width=832)[0] -animate_pose_video = VideoData("data/examples/wan/animate/animate_pose_video.mp4", height=480, width=832).raw_data()[:81-4] -animate_face_video = VideoData("data/examples/wan/animate/animate_face_video.mp4", height=512, width=512).raw_data()[:81-4] +animate_pose_video = VideoData("data/examples/wan/animate/animate_pose_video.mp4", height=480, width=832).raw_data()[:77] +animate_face_video = VideoData("data/examples/wan/animate/animate_face_video.mp4", height=512, width=512).raw_data()[:77] video = pipe( prompt="视频中的人在做动作", seed=0, tiled=True, input_image=input_image, animate_pose_video=animate_pose_video, animate_face_video=animate_face_video, - num_frames=81, height=480, width=832, + num_frames=77, height=480, width=832, num_inference_steps=20, cfg_scale=1, ) -save_video(video, "video_Wan2.2-Animate-14B.mp4", fps=15, quality=5) \ No newline at end of file +save_video(video, "video_Wan2.2-Animate-14B.mp4", fps=15, quality=5) + diff --git a/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py b/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py index 79326cd08..162ca0014 100644 --- a/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py +++ b/examples/wanvideo/model_training/validate_lora/Wan2.2-Animate-14B.py @@ -18,15 +18,16 @@ pipe.load_lora(pipe.dit, "models/train/Wan2.2-Animate-14B_lora/epoch-4.safetensors", alpha=1) input_image = VideoData("data/example_video_dataset/animate/animate_output.mp4", height=480, width=832)[0] -animate_pose_video = VideoData("data/examples/wan/animate/animate_pose_video.mp4", height=480, width=832).raw_data()[:81-4] -animate_face_video = VideoData("data/examples/wan/animate/animate_face_video.mp4", height=512, width=512).raw_data()[:81-4] +animate_pose_video = VideoData("data/examples/wan/animate/animate_pose_video.mp4", height=480, width=832).raw_data()[:77] +animate_face_video = VideoData("data/examples/wan/animate/animate_face_video.mp4", height=512, width=512).raw_data()[:77] video = pipe( prompt="视频中的人在做动作", seed=0, tiled=True, input_image=input_image, animate_pose_video=animate_pose_video, animate_face_video=animate_face_video, - num_frames=81, height=480, width=832, + num_frames=77, height=480, width=832, num_inference_steps=20, cfg_scale=1, ) -save_video(video, "video_Wan2.2-Animate-14B.mp4", fps=15, quality=5) \ No newline at end of file +save_video(video, "video_Wan2.2-Animate-14B.mp4", fps=15, quality=5) +