diff --git a/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py b/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py index 9813772a7c55..31ff9b3861ba 100644 --- a/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py +++ b/tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py @@ -13,27 +13,39 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest - +import pytest import torch from diffusers import AutoencoderKLHunyuanVideo from diffusers.models.autoencoders.autoencoder_kl_hunyuan_video import prepare_causal_attention_mask +from diffusers.utils.torch_utils import randn_tensor -from ...testing_utils import enable_full_determinism, floats_tensor, torch_device -from ..test_modeling_common import ModelTesterMixin -from .testing_utils import AutoencoderTesterMixin +from ...testing_utils import enable_full_determinism, torch_device +from ..testing_utils import BaseModelTesterConfig, MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin +from .testing_utils import NewAutoencoderTesterMixin enable_full_determinism() -class AutoencoderKLHunyuanVideoTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase): - model_class = AutoencoderKLHunyuanVideo - main_input_name = "sample" - base_precision = 1e-2 +class AutoencoderKLHunyuanVideoTesterConfig(BaseModelTesterConfig): + @property + def model_class(self): + return AutoencoderKLHunyuanVideo + + @property + def main_input_name(self) -> str: + return "sample" + + @property + def output_shape(self) -> tuple: + return (3, 9, 16, 16) + + @property + def generator(self): + return torch.Generator("cpu").manual_seed(0) - def get_autoencoder_kl_hunyuan_video_config(self): + def get_init_dict(self) -> dict: return { "in_channels": 3, "out_channels": 3, @@ -60,29 +72,49 @@ def get_autoencoder_kl_hunyuan_video_config(self): "mid_block_add_attention": True, } - @property - def dummy_input(self): + def get_dummy_inputs(self) -> dict: batch_size = 2 num_frames = 9 num_channels = 3 sizes = (16, 16) + image = randn_tensor( + (batch_size, num_channels, num_frames, *sizes), generator=self.generator, device=torch_device + ) + return {"sample": image} - image = floats_tensor((batch_size, num_channels, num_frames) + sizes).to(torch_device) - return {"sample": image} +class TestAutoencoderKLHunyuanVideo(AutoencoderKLHunyuanVideoTesterConfig, ModelTesterMixin): + base_precision = 1e-2 - @property - def input_shape(self): - return (3, 9, 16, 16) + @pytest.mark.skip("Unsupported test.") + def test_outputs_equivalence(self): + super().test_outputs_equivalence() - @property - def output_shape(self): - return (3, 9, 16, 16) + def test_prepare_causal_attention_mask(self): + def prepare_causal_attention_mask_orig( + num_frames: int, height_width: int, dtype: torch.dtype, device: torch.device, batch_size: int = None + ) -> torch.Tensor: + seq_len = num_frames * height_width + mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device) + for i in range(seq_len): + i_frame = i // height_width + mask[i, : (i_frame + 1) * height_width] = 0 + if batch_size is not None: + mask = mask.unsqueeze(0).expand(batch_size, -1, -1) + return mask + + # test with some odd shapes + original_mask = prepare_causal_attention_mask_orig( + num_frames=31, height_width=111, dtype=torch.float32, device=torch_device + ) + new_mask = prepare_causal_attention_mask( + num_frames=31, height_width=111, dtype=torch.float32, device=torch_device + ) + assert torch.allclose(original_mask, new_mask), "Causal attention mask should be the same" - def prepare_init_args_and_inputs_for_common(self): - init_dict = self.get_autoencoder_kl_hunyuan_video_config() - inputs_dict = self.dummy_input - return init_dict, inputs_dict + +class TestAutoencoderKLHunyuanVideoTraining(AutoencoderKLHunyuanVideoTesterConfig, TrainingTesterMixin): + """Training tests for AutoencoderKLHunyuanVideo.""" def test_gradient_checkpointing_is_applied(self): expected_set = { @@ -94,9 +126,18 @@ def test_gradient_checkpointing_is_applied(self): } super().test_gradient_checkpointing_is_applied(expected_set=expected_set) - # We need to overwrite this test because the base test does not account length of down_block_types + +class TestAutoencoderKLHunyuanVideoMemory(AutoencoderKLHunyuanVideoTesterConfig, MemoryTesterMixin): + """Memory optimization tests for AutoencoderKLHunyuanVideo.""" + + +class TestAutoencoderKLHunyuanVideoSlicingTiling(AutoencoderKLHunyuanVideoTesterConfig, NewAutoencoderTesterMixin): + """Slicing and tiling tests for AutoencoderKLHunyuanVideo.""" + + # Overwritten because the base test's block_out_channels doesn't account for the length of down_block_types. def test_forward_with_norm_groups(self): - init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + init_dict = self.get_init_dict() + inputs_dict = self.get_dummy_inputs() init_dict["norm_num_groups"] = 16 init_dict["block_out_channels"] = (16, 16, 16, 16) @@ -111,35 +152,6 @@ def test_forward_with_norm_groups(self): if isinstance(output, dict): output = output.to_tuple()[0] - self.assertIsNotNone(output) + assert output is not None expected_shape = inputs_dict["sample"].shape - self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match") - - @unittest.skip("Unsupported test.") - def test_outputs_equivalence(self): - pass - - def test_prepare_causal_attention_mask(self): - def prepare_causal_attention_mask_orig( - num_frames: int, height_width: int, dtype: torch.dtype, device: torch.device, batch_size: int = None - ) -> torch.Tensor: - seq_len = num_frames * height_width - mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device) - for i in range(seq_len): - i_frame = i // height_width - mask[i, : (i_frame + 1) * height_width] = 0 - if batch_size is not None: - mask = mask.unsqueeze(0).expand(batch_size, -1, -1) - return mask - - # test with some odd shapes - original_mask = prepare_causal_attention_mask_orig( - num_frames=31, height_width=111, dtype=torch.float32, device=torch_device - ) - new_mask = prepare_causal_attention_mask( - num_frames=31, height_width=111, dtype=torch.float32, device=torch_device - ) - self.assertTrue( - torch.allclose(original_mask, new_mask), - "Causal attention mask should be the same", - ) + assert output.shape == expected_shape, "Input and output shapes do not match"