Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 69 additions & 57 deletions tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,27 +13,39 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import pytest
import torch

from diffusers import AutoencoderKLHunyuanVideo
from diffusers.models.autoencoders.autoencoder_kl_hunyuan_video import prepare_causal_attention_mask
from diffusers.utils.torch_utils import randn_tensor

from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import BaseModelTesterConfig, MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin
from .testing_utils import NewAutoencoderTesterMixin


enable_full_determinism()


class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLHunyuanVideo
main_input_name = "sample"
base_precision = 1e-2
class AutoencoderKLHunyuanVideoTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return AutoencoderKLHunyuanVideo

@property
def main_input_name(self) -> str:
return "sample"

@property
def output_shape(self) -> tuple:
return (3, 9, 16, 16)

@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)

def get_autoencoder_kl_hunyuan_video_config(self):
def get_init_dict(self) -> dict:
return {
"in_channels": 3,
"out_channels": 3,
Expand All @@ -60,29 +72,49 @@ def get_autoencoder_kl_hunyuan_video_config(self):
"mid_block_add_attention": True,
}

@property
def dummy_input(self):
def get_dummy_inputs(self) -> dict:
batch_size = 2
num_frames = 9
num_channels = 3
sizes = (16, 16)
image = randn_tensor(
(batch_size, num_channels, num_frames, *sizes), generator=self.generator, device=torch_device
)
return {"sample": image}

image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device)

return {"sample": image}
class TestAutoencoderKLHunyuanVideo(AutoencoderKLHunyuanVideoTesterConfig, ModelTesterMixin):
base_precision = 1e-2

@property
def input_shape(self):
return (3, 9, 16, 16)
@pytest.mark.skip("Unsupported test.")
def test_outputs_equivalence(self):
super().test_outputs_equivalence()

@property
def output_shape(self):
return (3, 9, 16, 16)
def test_prepare_causal_attention_mask(self):
def prepare_causal_attention_mask_orig(
num_frames: int, height_width: int, dtype: torch.dtype, device: torch.device, batch_size: int = None
) -> torch.Tensor:
seq_len = num_frames * height_width
mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
for i in range(seq_len):
i_frame = i // height_width
mask[i, : (i_frame + 1) * height_width] = 0
if batch_size is not None:
mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
return mask

# test with some odd shapes
original_mask = prepare_causal_attention_mask_orig(
num_frames=31, height_width=111, dtype=torch.float32, device=torch_device
)
new_mask = prepare_causal_attention_mask(
num_frames=31, height_width=111, dtype=torch.float32, device=torch_device
)
assert torch.allclose(original_mask, new_mask), "Causal attention mask should be the same"

def prepare_init_args_and_inputs_for_common(self):
init_dict = self.get_autoencoder_kl_hunyuan_video_config()
inputs_dict = self.dummy_input
return init_dict, inputs_dict

class TestAutoencoderKLHunyuanVideoTraining(AutoencoderKLHunyuanVideoTesterConfig, TrainingTesterMixin):
"""Training tests for AutoencoderKLHunyuanVideo."""

def test_gradient_checkpointing_is_applied(self):
expected_set = {
Expand All @@ -94,9 +126,18 @@ def test_gradient_checkpointing_is_applied(self):
}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

# We need to overwrite this test because the base test does not account length of down_block_types

class TestAutoencoderKLHunyuanVideoMemory(AutoencoderKLHunyuanVideoTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for AutoencoderKLHunyuanVideo."""


class TestAutoencoderKLHunyuanVideoSlicingTiling(AutoencoderKLHunyuanVideoTesterConfig, NewAutoencoderTesterMixin):
"""Slicing and tiling tests for AutoencoderKLHunyuanVideo."""

# Overwritten because the base test's block_out_channels doesn't account for the length of down_block_types.
def test_forward_with_norm_groups(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
init_dict = self.get_init_dict()
inputs_dict = self.get_dummy_inputs()

init_dict["norm_num_groups"] = 16
init_dict["block_out_channels"] = (16, 16, 16, 16)
Expand All @@ -111,35 +152,6 @@ def test_forward_with_norm_groups(self):
if isinstance(output, dict):
output = output.to_tuple()[0]

self.assertIsNotNone(output)
assert output is not None
expected_shape = inputs_dict["sample"].shape
self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

@unittest.skip("Unsupported test.")
def test_outputs_equivalence(self):
pass

def test_prepare_causal_attention_mask(self):
def prepare_causal_attention_mask_orig(
num_frames: int, height_width: int, dtype: torch.dtype, device: torch.device, batch_size: int = None
) -> torch.Tensor:
seq_len = num_frames * height_width
mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
for i in range(seq_len):
i_frame = i // height_width
mask[i, : (i_frame + 1) * height_width] = 0
if batch_size is not None:
mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
return mask

# test with some odd shapes
original_mask = prepare_causal_attention_mask_orig(
num_frames=31, height_width=111, dtype=torch.float32, device=torch_device
)
new_mask = prepare_causal_attention_mask(
num_frames=31, height_width=111, dtype=torch.float32, device=torch_device
)
self.assertTrue(
torch.allclose(original_mask, new_mask),
"Causal attention mask should be the same",
)
assert output.shape == expected_shape, "Input and output shapes do not match"
Loading