Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions src/diffusers/models/autoencoders/autoencoder_kl_mochi.py
Original file line number Diff line number Diff line change
Expand Up @@ -1092,7 +1092,7 @@ def forward(
sample_posterior: bool = False,
return_dict: bool = True,
generator: torch.Generator | None = None,
) -> torch.Tensor | torch.Tensor:
) -> DecoderOutput | torch.Tensor:
r"""
Args:
sample (`torch.Tensor`): Input sample.
Expand All @@ -1110,7 +1110,5 @@ def forward(
z = posterior.sample(generator=generator)
else:
z = posterior.mode()
dec = self.decode(z)
if not return_dict:
return (dec,)
dec = self.decode(z, return_dict=return_dict)
return dec
75 changes: 44 additions & 31 deletions tests/models/autoencoders/test_models_autoencoder_cosmos.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,38 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import pytest
import torch

from diffusers import AutoencoderKLCosmos
from diffusers.utils.torch_utils import randn_tensor

from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import BaseModelTesterConfig, MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin
from .testing_utils import NewAutoencoderTesterMixin


enable_full_determinism()


class AutoencoderKLCosmosTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLCosmos
main_input_name = "sample"
base_precision = 1e-2
class AutoencoderKLCosmosTesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return AutoencoderKLCosmos

@property
def main_input_name(self) -> str:
return "sample"

@property
def output_shape(self) -> tuple:
return (3, 9, 32, 32)

@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)

def get_autoencoder_kl_cosmos_config(self):
def get_init_dict(self) -> dict:
return {
"in_channels": 3,
"out_channels": 3,
Expand All @@ -46,38 +60,37 @@ def get_autoencoder_kl_cosmos_config(self):
"temporal_compression_ratio": 4,
}

@property
def dummy_input(self):
def get_dummy_inputs(self) -> dict:
batch_size = 2
num_frames = 9
num_channels = 3
height = 32
width = 32

image = floats_tensor((batch_size, num_channels, num_frames, height, width)).to(torch_device)

image = randn_tensor(
(batch_size, num_channels, num_frames, height, width), generator=self.generator, device=torch_device
)
return {"sample": image}

@property
def input_shape(self):
return (3, 9, 32, 32)

@property
def output_shape(self):
return (3, 9, 32, 32)
class TestAutoencoderKLCosmos(AutoencoderKLCosmosTesterConfig, ModelTesterMixin):
base_precision = 1e-2


def prepare_init_args_and_inputs_for_common(self):
init_dict = self.get_autoencoder_kl_cosmos_config()
inputs_dict = self.dummy_input
return init_dict, inputs_dict
class TestAutoencoderKLCosmosTraining(AutoencoderKLCosmosTesterConfig, TrainingTesterMixin):
"""Training tests for AutoencoderKLCosmos."""

def test_gradient_checkpointing_is_applied(self):
expected_set = {
"CosmosEncoder3d",
"CosmosDecoder3d",
}
expected_set = {"CosmosEncoder3d", "CosmosDecoder3d"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)

@unittest.skip("Not sure why this test fails. Investigate later.")
def test_effective_gradient_checkpointing(self):
pass
@pytest.mark.skip("Not sure why this test fails. Investigate later.")
def test_gradient_checkpointing_equivalence(self):
super().test_gradient_checkpointing_equivalence()


class TestAutoencoderKLCosmosMemory(AutoencoderKLCosmosTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for AutoencoderKLCosmos."""


class TestAutoencoderKLCosmosSlicingTiling(AutoencoderKLCosmosTesterConfig, NewAutoencoderTesterMixin):
"""Slicing and tiling tests for AutoencoderKLCosmos."""
66 changes: 38 additions & 28 deletions tests/models/autoencoders/test_models_autoencoder_kl_kvae.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,24 +13,37 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import torch

from diffusers import AutoencoderKLKVAE
from diffusers.utils.torch_utils import randn_tensor

from ...testing_utils import enable_full_determinism, floats_tensor, torch_device
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import BaseModelTesterConfig, MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin
from .testing_utils import NewAutoencoderTesterMixin


enable_full_determinism()


class AutoencoderKLKVAETests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLKVAE
main_input_name = "sample"
base_precision = 1e-2
class AutoencoderKLKVAETesterConfig(BaseModelTesterConfig):
@property
def model_class(self):
return AutoencoderKLKVAE

@property
def main_input_name(self) -> str:
return "sample"

@property
def output_shape(self) -> tuple:
return (3, 32, 32)

@property
def generator(self):
return torch.Generator("cpu").manual_seed(0)

def get_autoencoder_kl_kvae_config(self):
def get_init_dict(self) -> dict:
return {
"in_channels": 3,
"channels": 32,
Expand All @@ -42,32 +55,29 @@ def get_autoencoder_kl_kvae_config(self):
"sample_size": 32,
}

@property
def dummy_input(self):
def get_dummy_inputs(self) -> dict:
batch_size = 2
num_channels = 3
sizes = (32, 32)

image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)

image = randn_tensor((batch_size, num_channels, *sizes), generator=self.generator, device=torch_device)
return {"sample": image}

@property
def input_shape(self):
return (3, 32, 32)

@property
def output_shape(self):
return (3, 32, 32)
class TestAutoencoderKLKVAE(AutoencoderKLKVAETesterConfig, ModelTesterMixin):
base_precision = 1e-2


def prepare_init_args_and_inputs_for_common(self):
init_dict = self.get_autoencoder_kl_kvae_config()
inputs_dict = self.dummy_input
return init_dict, inputs_dict
class TestAutoencoderKLKVAETraining(AutoencoderKLKVAETesterConfig, TrainingTesterMixin):
"""Training tests for AutoencoderKLKVAE."""

def test_gradient_checkpointing_is_applied(self):
expected_set = {
"KVAEEncoder2D",
"KVAEDecoder2D",
}
expected_set = {"KVAEEncoder2D", "KVAEDecoder2D"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)


class TestAutoencoderKLKVAEMemory(AutoencoderKLKVAETesterConfig, MemoryTesterMixin):
"""Memory optimization tests for AutoencoderKLKVAE."""


class TestAutoencoderKLKVAESlicingTiling(AutoencoderKLKVAETesterConfig, NewAutoencoderTesterMixin):
"""Slicing and tiling tests for AutoencoderKLKVAE."""
Original file line number Diff line number Diff line change
Expand Up @@ -13,58 +13,72 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
import torch

from diffusers import AutoencoderKLTemporalDecoder
from diffusers.utils.torch_utils import randn_tensor

from ...testing_utils import (
enable_full_determinism,
floats_tensor,
torch_device,
)
from ..test_modeling_common import ModelTesterMixin
from .testing_utils import AutoencoderTesterMixin
from ...testing_utils import enable_full_determinism, torch_device
from ..testing_utils import BaseModelTesterConfig, MemoryTesterMixin, ModelTesterMixin, TrainingTesterMixin
from .testing_utils import NewAutoencoderTesterMixin


enable_full_determinism()


class AutoencoderKLTemporalDecoderTests(ModelTesterMixin, AutoencoderTesterMixin, unittest.TestCase):
model_class = AutoencoderKLTemporalDecoder
main_input_name = "sample"
base_precision = 1e-2

class AutoencoderKLTemporalDecoderTesterConfig(BaseModelTesterConfig):
@property
def dummy_input(self):
batch_size = 3
num_channels = 3
sizes = (32, 32)

image = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
num_frames = 3
def model_class(self):
return AutoencoderKLTemporalDecoder

return {"sample": image, "num_frames": num_frames}
@property
def main_input_name(self) -> str:
return "sample"

@property
def input_shape(self):
def output_shape(self) -> tuple:
return (3, 32, 32)

@property
def output_shape(self):
return (3, 32, 32)
def generator(self):
return torch.Generator("cpu").manual_seed(0)

def prepare_init_args_and_inputs_for_common(self):
init_dict = {
def get_init_dict(self) -> dict:
return {
"block_out_channels": [32, 64],
"in_channels": 3,
"out_channels": 3,
"down_block_types": ["DownEncoderBlock2D", "DownEncoderBlock2D"],
"latent_channels": 4,
"layers_per_block": 2,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict

def get_dummy_inputs(self) -> dict:
batch_size = 3
num_channels = 3
sizes = (32, 32)
image = randn_tensor((batch_size, num_channels, *sizes), generator=self.generator, device=torch_device)
num_frames = 3
return {"sample": image, "num_frames": num_frames}


class TestAutoencoderKLTemporalDecoder(AutoencoderKLTemporalDecoderTesterConfig, ModelTesterMixin):
base_precision = 1e-2


class TestAutoencoderKLTemporalDecoderTraining(AutoencoderKLTemporalDecoderTesterConfig, TrainingTesterMixin):
"""Training tests for AutoencoderKLTemporalDecoder."""

def test_gradient_checkpointing_is_applied(self):
expected_set = {"Encoder", "TemporalDecoder", "UNetMidBlock2D"}
super().test_gradient_checkpointing_is_applied(expected_set=expected_set)


class TestAutoencoderKLTemporalDecoderMemory(AutoencoderKLTemporalDecoderTesterConfig, MemoryTesterMixin):
"""Memory optimization tests for AutoencoderKLTemporalDecoder."""


class TestAutoencoderKLTemporalDecoderSlicingTiling(
AutoencoderKLTemporalDecoderTesterConfig, NewAutoencoderTesterMixin
):
"""Slicing and tiling tests for AutoencoderKLTemporalDecoder."""
Loading
Loading