Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
155 changes: 147 additions & 8 deletions easy_samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,50 @@
from .nodes_registry import comfy_node


def _make_av_latent_dict(video_latent_dict, audio_tensor, audio_noise_mask=None):
"""Wrap video latent dict + audio tensor into AV latent dict with NestedTensor.

If audio_tensor is None, returns video_latent_dict unchanged.
Creates matching noise masks for both modalities when either is present.
"""
if audio_tensor is None:
return video_latent_dict
result = video_latent_dict.copy()
result["samples"] = NestedTensor([result["samples"], audio_tensor])
video_mask = result.get("noise_mask")
if video_mask is not None or audio_noise_mask is not None:
if video_mask is None:
vs = result["samples"].tensors[0]
video_mask = torch.ones(
vs.shape[0], 1, vs.shape[2], vs.shape[3], vs.shape[4],
device=vs.device, dtype=vs.dtype,
)
if audio_noise_mask is None:
audio_noise_mask = torch.ones(
audio_tensor.shape[0], 1, audio_tensor.shape[2], audio_tensor.shape[3],
device=audio_tensor.device, dtype=audio_tensor.dtype,
)
result["noise_mask"] = NestedTensor([video_mask, audio_noise_mask])
return result


def _split_av_latent_dict(latent_dict):
"""Split AV latent dict into (video_latent_dict, audio_tensor).

If the latent is not an AV NestedTensor, returns (latent_dict, None).
"""
samples = latent_dict["samples"]
if not isinstance(samples, NestedTensor) or len(samples.tensors) < 2:
return latent_dict, None
result = latent_dict.copy()
result["samples"] = samples.tensors[0]
audio = samples.tensors[1]
nm = result.get("noise_mask")
if nm is not None and isinstance(nm, NestedTensor):
result["noise_mask"] = nm.tensors[0]
return result, audio


def _get_raw_conds_from_guider(guider):
if not hasattr(guider, "raw_conds"):
if "negative" not in guider.original_conds:
Expand Down Expand Up @@ -148,6 +192,7 @@ def sample(
optional_initialization_latents=None,
guiding_start_step=0,
guiding_end_step=1000,
_audio_tile=None,
):
guider = copy.copy(guider)
guider.original_conds = copy.deepcopy(guider.original_conds)
Expand Down Expand Up @@ -262,13 +307,15 @@ def sample(

# Denoise the latent video
print("Denoising with conditioning on sigmas: ", middle_sigmas)
_av = _make_av_latent_dict(latents, _audio_tile)
(output_latents, denoised_output_latents) = SamplerCustomAdvanced().sample(
noise=noise,
guider=guider,
sampler=sampler,
sigmas=middle_sigmas,
latent_image=latents,
latent_image=_av,
)
denoised_output_latents, _audio_tile = _split_av_latent_dict(denoised_output_latents)

# Clean up guides if image conditioning was used
positive, negative, denoised_output_latents = LTXVCropGuides.execute(
Expand All @@ -284,13 +331,18 @@ def sample(
"Denoising with no conditioning but with classical i2v noise mask on sigmas: ",
low_sigmas,
)
_av = _make_av_latent_dict(denoised_output_latents, _audio_tile)
(_, denoised_output_latents) = SamplerCustomAdvanced().sample(
noise=noise,
guider=guider,
sampler=sampler,
sigmas=low_sigmas,
latent_image=denoised_output_latents,
latent_image=_av,
)
denoised_output_latents, _audio_tile = _split_av_latent_dict(denoised_output_latents)

if _audio_tile is not None:
denoised_output_latents["_audio"] = _audio_tile

return (denoised_output_latents, positive, negative)

Expand Down Expand Up @@ -399,6 +451,8 @@ def sample(
guiding_start_step=0,
guiding_end_step=1000,
normalize_per_frame=False,
_audio_tile=None,
_audio_new_init=None,
):
guider = copy.copy(guider)
guider.original_conds = copy.deepcopy(guider.original_conds)
Expand All @@ -412,7 +466,20 @@ def sample(

positive, negative = _get_raw_conds_from_guider(guider)

# Handle AV latents (standalone mode)
_standalone_av = False
_accumulated_audio = _audio_tile
samples = latents["samples"]
if isinstance(samples, NestedTensor) and len(samples.tensors) == 2:
if _accumulated_audio is None:
_accumulated_audio = samples.tensors[1]
_standalone_av = True
latents = latents.copy()
latents["samples"] = samples.tensors[0]
if "noise_mask" in latents and isinstance(latents["noise_mask"], NestedTensor):
latents["noise_mask"] = latents["noise_mask"].tensors[0]
samples = latents["samples"]

batch, channels, frames, height, width = samples.shape
time_scale_factor, width_scale_factor, height_scale_factor = (
vae.downscale_index_formula
Expand All @@ -428,6 +495,52 @@ def sample(
latents, -overlap, -1
)

# Set up audio extend tile if audio is available
_audio_extend_tile = None
_audio_noise_mask = None
_audio_overlap = 0
if _accumulated_audio is not None:
audio_T = _accumulated_audio.shape[2]
video_T = frames
audio_ratio = audio_T / max(video_T, 1)
_audio_overlap = max(1, round(overlap * audio_ratio))
video_new_latent_frames = num_new_frames // time_scale_factor
audio_new_frames = max(1, round(video_new_latent_frames * audio_ratio))

# Build audio tile: overlap (already denoised) + new frames.
# If _audio_new_init is provided (stage-2 refinement), use it
# as initialization for the new frames instead of zeros.
audio_overlap_data = _accumulated_audio[:, :, -_audio_overlap:]
if _audio_new_init is not None:
available = min(audio_new_frames, _audio_new_init.shape[2])
audio_new_data = _audio_new_init[:, :, :available].clone()
if available < audio_new_frames:
pad = torch.zeros(
_accumulated_audio.shape[0], _accumulated_audio.shape[1],
audio_new_frames - available, _accumulated_audio.shape[3],
device=_accumulated_audio.device, dtype=_accumulated_audio.dtype,
)
audio_new_data = torch.cat([audio_new_data, pad], dim=2)
else:
audio_new_data = torch.zeros(
_accumulated_audio.shape[0], _accumulated_audio.shape[1],
audio_new_frames, _accumulated_audio.shape[3],
device=_accumulated_audio.device, dtype=_accumulated_audio.dtype,
)
_audio_extend_tile = torch.cat([audio_overlap_data, audio_new_data], dim=2)

# Audio noise mask: preserve overlap, denoise new
_audio_noise_mask = torch.ones(
_audio_extend_tile.shape[0], 1,
_audio_extend_tile.shape[2], _audio_extend_tile.shape[3],
device=_audio_extend_tile.device, dtype=_audio_extend_tile.dtype,
)
_audio_noise_mask[:, :, :_audio_overlap] = 1.0 - strength
print(
f"[ExtendSampler] Audio extend tile: overlap={_audio_overlap}, "
f"new={audio_new_frames}, total={_audio_extend_tile.shape[2]}"
)

if optional_initialization_latents is None:
new_latents = EmptyLTXVLatentVideo.execute(
width=width * width_scale_factor,
Expand Down Expand Up @@ -488,13 +601,15 @@ def sample(
if len(high_sigmas) > 1:
guider.set_conds(positive, negative)
print("Denoising with overlap conditioning only on sigmas: ", high_sigmas)
_av = _make_av_latent_dict(new_latents, _audio_extend_tile, _audio_noise_mask)
(_, new_latents) = SamplerCustomAdvanced().sample(
noise=noise,
guider=guider,
sampler=sampler,
sigmas=high_sigmas,
latent_image=new_latents,
latent_image=_av,
)
new_latents, _audio_extend_tile = _split_av_latent_dict(new_latents)

if optional_guiding_latents is not None:
optional_guiding_latents = LTXVSelectLatents().select_latents(
Expand Down Expand Up @@ -533,13 +648,15 @@ def sample(

# Denoise the latent video
print("Denoising with full conditioning on sigmas: ", middle_sigmas)
_av = _make_av_latent_dict(new_latents, _audio_extend_tile, _audio_noise_mask)
(output_latents, denoised_output_latents) = SamplerCustomAdvanced().sample(
noise=noise,
guider=guider,
sampler=sampler,
sigmas=middle_sigmas,
latent_image=new_latents,
latent_image=_av,
)
denoised_output_latents, _audio_extend_tile = _split_av_latent_dict(denoised_output_latents)

positive, negative, denoised_output_latents = LTXVCropGuides.execute(
positive=positive,
Expand Down Expand Up @@ -591,13 +708,15 @@ def sample(
"Denoising with overlap + keyframes conditioning only on sigmas: ",
low_sigmas,
)
_av = _make_av_latent_dict(denoised_output_latents, _audio_extend_tile, _audio_noise_mask)
(_, denoised_output_latents) = SamplerCustomAdvanced().sample(
noise=noise,
guider=guider,
sampler=sampler,
sigmas=low_sigmas,
latent_image=denoised_output_latents,
latent_image=_av,
)
denoised_output_latents, _audio_extend_tile = _split_av_latent_dict(denoised_output_latents)
positive, negative, denoised_output_latents = LTXVCropGuides.execute(
positive=positive,
negative=negative,
Expand All @@ -621,6 +740,16 @@ def sample(
(latents,) = LinearOverlapLatentTransition().process(
latents, truncated_denoised_output_latents, overlap - 1, axis=2
)

# Accumulate audio: append new (non-overlap) audio frames
if _accumulated_audio is not None and _audio_extend_tile is not None:
new_audio = _audio_extend_tile[:, :, _audio_overlap:]
accumulated_audio_out = torch.cat([_accumulated_audio, new_audio], dim=2)
if _standalone_av:
latents["samples"] = NestedTensor([latents["samples"], accumulated_audio_out])
else:
latents["_audio"] = accumulated_audio_out

return (latents, positive, negative)


Expand Down Expand Up @@ -692,6 +821,7 @@ def sample(
guiding_strength=1.0,
guiding_start_step=0,
guiding_end_step=1000,
_audio_tile=None,
):
guider = copy.copy(guider)
guider.original_conds = copy.deepcopy(guider.original_conds)
Expand Down Expand Up @@ -735,13 +865,15 @@ def sample(
"Denoising with keyframes only [if available] on sigmas: ",
high_sigmas,
)
_av = _make_av_latent_dict(new_latents, _audio_tile)
(_, new_latents) = SamplerCustomAdvanced().sample(
noise=noise,
guider=guider,
sampler=sampler,
sigmas=high_sigmas,
latent_image=new_latents,
latent_image=_av,
)
new_latents, _audio_tile = _split_av_latent_dict(new_latents)

if optional_cond_indices is not None and 0 in optional_cond_indices:
guiding_latents = LTXVSelectLatents().select_latents(
Expand Down Expand Up @@ -806,13 +938,15 @@ def sample(

# Denoise the latent video
print("Denoising with full conditioning on sigmas: ", middle_sigmas)
_av = _make_av_latent_dict(new_latents, _audio_tile)
(_, denoised_output_latents) = SamplerCustomAdvanced().sample(
noise=noise,
guider=guider,
sampler=sampler,
sigmas=middle_sigmas,
latent_image=new_latents,
latent_image=_av,
)
denoised_output_latents, _audio_tile = _split_av_latent_dict(denoised_output_latents)

# Clean up guides if image conditioning was used
positive, negative, denoised_output_latents = LTXVCropGuides.execute(
Expand All @@ -827,19 +961,24 @@ def sample(
"Denoising with keyframes only [if available] conditioning on sigmas: ",
low_sigmas,
)
_av = _make_av_latent_dict(denoised_output_latents, _audio_tile)
(_, denoised_output_latents) = SamplerCustomAdvanced().sample(
noise=noise,
guider=guider,
sampler=sampler,
sigmas=low_sigmas,
latent_image=denoised_output_latents,
latent_image=_av,
)
denoised_output_latents, _audio_tile = _split_av_latent_dict(denoised_output_latents)
positive, negative, denoised_output_latents = LTXVCropGuides.execute(
positive=positive,
negative=negative,
latent=denoised_output_latents,
)

if _audio_tile is not None:
denoised_output_latents["_audio"] = _audio_tile

return (denoised_output_latents, positive, negative)


Expand Down
Loading