The CIFAR-10 dataset is a dataset of 32x32 color images, from 10 classes (airplane, automobile, bird, cat, deer, dog, frog, horse, ship, truck). We'll train a diffusion model on this dataset, and generate samples from the model.
We'll start by using the same model architecture as in the previous chapter, and train for 80 epochs.
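The model reuses the helper modules from the previous chapter. As a reminder, here is a minimal sketch of what PositionalEncoding, Downsample, and Upsample might look like (the exact definitions in the previous chapter may differ slightly; ResnetBlock is shown in full later in this chapter):

```python
import math
import torch

class PositionalEncoding(torch.nn.Module):
    """Sinusoidal embedding of the integer timestep, precomputed as a lookup table."""
    def __init__(self, embed_dim, num_steps):
        super().__init__()
        position = torch.arange(num_steps).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim))
        pe = torch.zeros(num_steps, embed_dim)
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer('pe', pe)

    def forward(self, t):
        # t is a batch of integer timesteps; index into the table
        return self.pe[t]

class Downsample(torch.nn.Module):
    """Halve the spatial resolution with a strided convolution."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=1)

    def forward(self, x):
        return self.conv(x)

class Upsample(torch.nn.Module):
    """Double the spatial resolution with nearest-neighbor upsampling plus a convolution."""
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x):
        return self.conv(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest'))
```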
class Model(torch.nn.Module):
    def __init__(self, num_steps=1000, embed_dim=64):
        super(Model, self).__init__()
        self.embed = torch.nn.Sequential(
            PositionalEncoding(embed_dim, num_steps),
            torch.nn.Linear(embed_dim, embed_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(embed_dim, embed_dim),
            torch.nn.ReLU(),
        )
        self.conv_in = torch.nn.Conv2d(3, 16, kernel_size=3, padding=1)
        self.enc1_1 = ResnetBlock(16, 16, embed_dim)
        self.enc1_2 = ResnetBlock(16, 32, embed_dim)
        self.downconv1 = Downsample(32, 32)
        self.enc2_1 = ResnetBlock(32, 32, embed_dim)
        self.enc2_2 = ResnetBlock(32, 64, embed_dim)
        self.downconv2 = Downsample(64, 64)
        self.bottleneck_1 = ResnetBlock(64, 64, embed_dim)
        self.bottleneck_2 = ResnetBlock(64, 64, embed_dim)
        self.upconv2 = Upsample(64, 64)
        self.dec2_1 = ResnetBlock(128, 64, embed_dim)
        self.dec2_2 = ResnetBlock(64, 32, embed_dim)
        self.upconv1 = Upsample(32, 32)
        self.dec1_1 = ResnetBlock(64, 32, embed_dim)
        self.dec1_2 = ResnetBlock(32, 16, embed_dim)
        self.norm_out = torch.nn.GroupNorm(16, 16)
        self.conv_out = torch.nn.Conv2d(16, 3, kernel_size=3, padding=1)

    def forward(self, x, t):
        emb = self.embed(t)
        x = self.conv_in(x)
        x = self.enc1_1(x, emb)
        enc1 = self.enc1_2(x, emb)
        x = self.downconv1(enc1)
        x = self.enc2_1(x, emb)
        enc2 = self.enc2_2(x, emb)
        x = self.downconv2(enc2)
        x = self.bottleneck_1(x, emb)
        x = self.bottleneck_2(x, emb)
        x = self.upconv2(x)
        x = torch.cat([x, enc2], 1)
        x = self.dec2_1(x, emb)
        x = self.dec2_2(x, emb)
        x = self.upconv1(x)
        x = torch.cat([x, enc1], 1)
        x = self.dec1_1(x, emb)
        x = self.dec1_2(x, emb)
        x = self.norm_out(x)
        x = torch.nn.functional.relu(x)
        x = self.conv_out(x)
        return x

We can run this with:
python part_a_cifar.py train

You can also change additional hyperparameters, such as the learning rate, batch size, and number of epochs:

python part_a_cifar.py train --batch-size 128 --lr 1e-3 --epochs 80

After training, we can sample a grid of images from the model with:

python part_a_cifar.py test

Our resulting output looks like this:
By increasing the number of channels in our model, we can improve the quality of the generated images. Let's double the base number of channels in our model from 16 to 32.
class Model(torch.nn.Module):
    def __init__(self, num_steps=1000, ch=32):
        super(Model, self).__init__()
        embed_dim = ch*4
        self.embed = torch.nn.Sequential(
            PositionalEncoding(ch, num_steps),
            torch.nn.Linear(ch, embed_dim),
            torch.nn.ReLU(),
            torch.nn.Linear(embed_dim, embed_dim),
            torch.nn.ReLU(),
        )
        self.conv_in = torch.nn.Conv2d(3, ch, kernel_size=3, padding=1)
        self.enc1_1 = ResnetBlock(ch, ch, embed_dim)
        self.enc1_2 = ResnetBlock(ch, ch*2, embed_dim)
        self.downconv1 = Downsample(ch*2, ch*2)
        self.enc2_1 = ResnetBlock(ch*2, ch*2, embed_dim)
        self.enc2_2 = ResnetBlock(ch*2, ch*4, embed_dim)
        self.downconv2 = Downsample(ch*4, ch*4)
        self.bottleneck_1 = ResnetBlock(ch*4, ch*4, embed_dim)
        self.bottleneck_2 = ResnetBlock(ch*4, ch*4, embed_dim)
        self.upconv2 = Upsample(ch*4, ch*4)
        self.dec2_1 = ResnetBlock(ch*8, ch*4, embed_dim)
        self.dec2_2 = ResnetBlock(ch*4, ch*2, embed_dim)
        self.upconv1 = Upsample(ch*2, ch*2)
        self.dec1_1 = ResnetBlock(ch*4, ch*2, embed_dim)
        self.dec1_2 = ResnetBlock(ch*2, ch, embed_dim)
        self.norm_out = torch.nn.GroupNorm(32, ch)
        self.conv_out = torch.nn.Conv2d(ch, 3, kernel_size=3, padding=1)

    def forward(self, x, t):
        emb = self.embed(t)
        x = self.conv_in(x)
        x = self.enc1_1(x, emb)
        enc1 = self.enc1_2(x, emb)
        x = self.downconv1(enc1)
        x = self.enc2_1(x, emb)
        enc2 = self.enc2_2(x, emb)
        x = self.downconv2(enc2)
        x = self.bottleneck_1(x, emb)
        x = self.bottleneck_2(x, emb)
        x = self.upconv2(x)
        x = torch.cat([x, enc2], 1)
        x = self.dec2_1(x, emb)
        x = self.dec2_2(x, emb)
        x = self.upconv1(x)
        x = torch.cat([x, enc1], 1)
        x = self.dec1_1(x, emb)
        x = self.dec1_2(x, emb)
        x = self.norm_out(x)
        x = torch.nn.functional.relu(x)
        x = self.conv_out(x)
        return x

Let's train this model now:
python part_b_cifar_more_channels.py train

You can also change additional hyperparameters, such as the learning rate, batch size, number of epochs, and number of channels:

python part_b_cifar_more_channels.py train --batch-size 128 --lr 1e-3 --epochs 120 --model-channels 32

After training, we can sample a grid of images from the model with:

python part_b_cifar_more_channels.py test --model-channels 32

Our resulting output looks like this:
We can improve our model by replacing the ReLU activation function with SiLU, the same activation function used in the DDPM paper. We simply replace torch.nn.ReLU() with torch.nn.SiLU() and torch.nn.functional.relu() with torch.nn.functional.silu().
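SiLU (also known as the swish function) is defined as silu(x) = x * sigmoid(x). A quick check of that identity, and of how it differs from ReLU for negative inputs:

```python
import torch

x = torch.linspace(-3.0, 3.0, steps=7)
silu = torch.nn.functional.silu(x)

# SiLU is x * sigmoid(x): smooth, and non-zero for negative inputs,
# unlike ReLU, which zeroes everything below 0.
assert torch.allclose(silu, x * torch.sigmoid(x))
assert silu[0] < 0  # silu(-3) is a small negative value, while relu(-3) = 0
```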
Let's train this model now:
python part_c_cifar_silu.py train

You can also change additional hyperparameters, such as the learning rate, batch size, number of epochs, number of channels, and the activation function (relu, silu, leakyrelu, gelu):

python part_c_cifar_silu.py train --batch-size 128 --lr 1e-3 --epochs 120 --model-channels 32 --activation silu

After training, we can sample a grid of images from the model with:

python part_c_cifar_silu.py test --model-channels 32 --activation silu

Our resulting output looks like this:
We'll now refactor our architecture to be more like the one used in the DDPM paper.
class Model(torch.nn.Module):
    def __init__(self, num_steps=1000, ch=32):
        super(Model, self).__init__()
        embed_dim = ch * 4
        self.embed = torch.nn.Sequential(
            PositionalEncoding(ch, num_steps),
            torch.nn.Linear(ch, embed_dim),
            torch.nn.SiLU(),
            torch.nn.Linear(embed_dim, embed_dim),
            torch.nn.SiLU(),
        )
        self.input_blocks = torch.nn.ModuleList([
            torch.nn.Conv2d(3, ch, kernel_size=3, padding=1),
            ResnetBlock(ch, ch, embed_dim),
            ResnetBlock(ch, ch, embed_dim),
            Downsample(ch, ch),
            ResnetBlock(ch, ch*2, embed_dim),
            ResnetBlock(ch*2, ch*2, embed_dim),
            Downsample(ch*2, ch*2),
            ResnetBlock(ch*2, ch*2, embed_dim),
            ResnetBlock(ch*2, ch*2, embed_dim),
        ])
        self.middle_block = TimestepBlockSequential(
            ResnetBlock(ch*2, ch*2, embed_dim),
            ResnetBlock(ch*2, ch*2, embed_dim),
        )
        self.output_blocks = torch.nn.ModuleList([
            ResnetBlock(ch*4, ch*2, embed_dim),
            ResnetBlock(ch*4, ch*2, embed_dim),
            TimestepBlockSequential(
                ResnetBlock(ch*4, ch*2, embed_dim),
                Upsample(ch*2, ch*2),
            ),
            ResnetBlock(ch*4, ch*2, embed_dim),
            ResnetBlock(ch*4, ch*2, embed_dim),
            TimestepBlockSequential(
                ResnetBlock(ch*3, ch*2, embed_dim),
                Upsample(ch*2, ch*2),
            ),
            ResnetBlock(ch*3, ch, embed_dim),
            ResnetBlock(ch*2, ch, embed_dim),
            ResnetBlock(ch*2, ch, embed_dim),
        ])
        self.out = torch.nn.Sequential(
            torch.nn.GroupNorm(32, ch),
            torch.nn.SiLU(),
            torch.nn.Conv2d(ch, 3, kernel_size=3, padding=1)
        )

    def forward(self, x, t):
        emb = self.embed(t)
        hs = []
        for module in self.input_blocks:
            if isinstance(module, TimestepBlock):
                x = module(x, emb)
            else:
                x = module(x)
            hs.append(x)
        x = self.middle_block(x, emb)
        for module in self.output_blocks:
            x = torch.cat([x, hs.pop()], 1)
            x = module(x, emb)
        return self.out(x)

Each block is now a separate module, and we use a TimestepBlockSequential to apply a sequence of blocks to the input.
Instead of concatenating only the output of each resolution, we now concatenate the output of every input block with the input of the corresponding output block.
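The code relies on two small utilities not shown above: TimestepBlock, a marker base class for modules whose forward takes the timestep embedding, and TimestepBlockSequential, which applies a sequence of blocks and passes the embedding only to those that want it. A minimal sketch consistent with the usage above:

```python
import torch

class TimestepBlock(torch.nn.Module):
    """Marker base class: any module whose forward takes (x, emb)."""

class TimestepBlockSequential(torch.nn.Sequential, TimestepBlock):
    """Applies child modules in order, passing the timestep embedding
    only to children that are themselves TimestepBlocks."""
    def forward(self, x, emb):
        for module in self:
            if isinstance(module, TimestepBlock):
                x = module(x, emb)
            else:
                x = module(x)
        return x
```

Because it subclasses torch.nn.Sequential, it also inherits append(), which the builder code in the next part uses.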
Let's train this model now:
python part_d_cifar_arch.py train

You can also change additional hyperparameters, such as the learning rate, batch size, and number of epochs:

python part_d_cifar_arch.py train --batch-size 128 --lr 1e-3 --epochs 120 --model-channels 32 --activation silu

After training, we can sample a grid of images from the model with:

python part_d_cifar_arch.py test --model-channels 32 --activation silu

Our resulting output looks like this:
Let's increase the number of downsampling steps from 2 to 3. We'll also restructure the model's initialization to make the number of downsampling steps easy to change.
We'll now use num_res_blocks (default 2) to control the number of residual blocks at each level, and channel_mult (default (2, 2, 2, 1)) to control the multiplier for the number of channels at each level.
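To sanity-check the bookkeeping, here is a short standalone simulation of the skip-connection channel counts for the defaults (model_channels=32 is a hypothetical value matching the commands below):

```python
# Simulate the encoder-side channel tracking: one skip connection is
# recorded for conv_in, for every ResnetBlock, and for every Downsample.
model_channels = 32
channel_mult = (2, 2, 2, 1)
num_res_blocks = 2

channels = [model_channels]  # output of the initial conv
for level, mult in enumerate(channel_mult):
    out_ch = model_channels * mult
    for _ in range(num_res_blocks):
        channels.append(out_ch)      # skip from each ResnetBlock
    if level < len(channel_mult) - 1:
        channels.append(out_ch)      # skip from each Downsample

print(channels)
# → [32, 64, 64, 64, 64, 64, 64, 64, 64, 64, 32, 32]
```

The decoder pops num_res_blocks + 1 entries per level (4 levels x 3 = 12 pops against 12 entries), so the list is drained exactly.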
self.input_blocks = torch.nn.ModuleList([torch.nn.Conv2d(image_channels, model_channels, kernel_size=3, padding=1)])
channels = [model_channels]
for level, mult in enumerate(channel_mult):
    out_ch = model_channels * mult
    for _ in range(num_res_blocks):
        in_ch = channels[-1]
        self.input_blocks.append(ResnetBlock(in_ch, out_ch, embed_dim, activation_fn))
        channels.append(out_ch)
    if level < len(channel_mult) - 1:
        self.input_blocks.append(Downsample(out_ch, out_ch))
        channels.append(out_ch)
self.middle_block = TimestepBlockSequential()
out_ch = model_channels * channel_mult[-1]
for _ in range(num_res_blocks):
    self.middle_block.append(ResnetBlock(out_ch, out_ch, embed_dim, activation_fn))
self.output_blocks = torch.nn.ModuleList()
for level, mult in enumerate(reversed(channel_mult)):
    for i in range(num_res_blocks + 1):
        in_ch = out_ch + channels.pop()
        out_ch = model_channels * mult
        block = TimestepBlockSequential(ResnetBlock(in_ch, out_ch, embed_dim, activation_fn))
        if i == num_res_blocks and level < len(channel_mult) - 1:
            block.append(Upsample(out_ch, out_ch))
        self.output_blocks.append(block)

Let's train this model now:
python part_e_cifar_deeper.py train

You can also change additional hyperparameters, such as the learning rate, batch size, number of epochs, number of channels, activation function (relu, silu, leakyrelu, gelu), number of residual blocks, and channel multiplier:

python part_e_cifar_deeper.py train --batch-size 128 --lr 1e-3 --epochs 120 --model-channels 32 --activation silu --num-res-blocks 2 --channel-mult 2 2 2 1

After training, we can sample a grid of images from the model with:

python part_e_cifar_deeper.py test --model-channels 32 --activation silu --num-res-blocks 2 --channel-mult 2 2 2 1

Our resulting output looks like this:
We'll now add image augmentation to our training data, to improve the generalization of our model.
Specifically, here we'll use random horizontal flips.
transform = torchvision.transforms.Compose([
    torchvision.transforms.RandomHorizontalFlip(),
    torchvision.transforms.ToTensor(),
    normalize
])
dataset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)

Let's train this model now:
python part_f_cifar_hflips.py train --hflip

You can also change additional hyperparameters, such as the learning rate, batch size, number of epochs, number of channels, activation function (relu, silu, leakyrelu, gelu), number of residual blocks, and channel multiplier:

python part_f_cifar_hflips.py train --batch-size 128 --lr 1e-3 --epochs 120 --model-channels 32 --activation silu --num-res-blocks 2 --channel-mult 2 2 2 1 --hflip

After training, we can sample a grid of images from the model with:

python part_f_cifar_hflips.py test --model-channels 32 --activation silu --num-res-blocks 2 --channel-mult 2 2 2 1

Our resulting output looks like this:
We'll add dropout to our residual blocks. The DDPM paper notes that a dropout rate of 0.1 improved the quality of the generated images.
class ResnetBlock(TimestepBlock):
    def __init__(self, in_channels, out_channels, embed_channels, activation_fn=torch.nn.SiLU, dropout=0.1):
        super(ResnetBlock, self).__init__()
        self.in_layers = torch.nn.Sequential(
            torch.nn.GroupNorm(32, in_channels),
            activation_fn(),
            torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1)
        )
        self.emb_layers = torch.nn.Sequential(
            activation_fn(),
            torch.nn.Linear(embed_channels, out_channels)
        )
        self.out_layers = torch.nn.Sequential(
            torch.nn.GroupNorm(32, out_channels),
            activation_fn(),
            torch.nn.Dropout(dropout),
            torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        )
        if in_channels != out_channels:
            self.shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1)
        else:
            self.shortcut = torch.nn.Identity()

Let's train this model now:
python part_g_dropout.py train --hflip

You can also change additional hyperparameters, such as the learning rate, batch size, number of epochs, number of channels, activation function (relu, silu, leakyrelu, gelu), number of residual blocks, channel multiplier, and dropout:

python part_g_dropout.py train --batch-size 128 --lr 1e-3 --epochs 120 --model-channels 32 --activation silu --num-res-blocks 2 --channel-mult 2 2 2 1 --hflip --dropout 0.1

After training, we can sample a grid of images from the model with:

python part_g_dropout.py test --model-channels 32 --activation silu --num-res-blocks 2 --channel-mult 2 2 2 1

Our resulting output looks like this:
We'll now add attention blocks to our model to improve the quality of the generated images.
We'll use a 1x1 convolution to project the input to query, key, and value vectors, apply scaled dot-product attention across the spatial positions (with the channels as the feature dimension), and then project the output back to the input dimension.
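Note that attention here runs over the h*w spatial positions, with one c-dimensional feature vector per pixel. A quick standalone shape check of that arrangement (the sizes are hypothetical):

```python
import torch

b, c, h, w = 2, 32, 8, 8
x = torch.randn(b, 3 * c, h * w)        # what the 1x1 qkv conv produces
qkv = x.permute(0, 2, 1)                # (batch, positions, 3*channels)
q, k, v = torch.chunk(qkv, 3, dim=2)    # each (batch, positions, channels)
out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
assert out.shape == (b, h * w, c)       # one attended vector per pixel
```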
class AttentionBlock(torch.nn.Module):
    def __init__(self, in_channels):
        super(AttentionBlock, self).__init__()
        self.norm = torch.nn.GroupNorm(32, in_channels)
        self.qkv = torch.nn.Conv1d(in_channels, in_channels*3, kernel_size=1)
        self.out = torch.nn.Conv1d(in_channels, in_channels, kernel_size=1)

    def forward(self, x):
        b, c, h, w = x.shape
        _input = x.view(b, c, -1)
        x = self.norm(_input)
        qkv = self.qkv(x).permute(0, 2, 1)
        q, k, v = torch.chunk(qkv, 3, dim=2)
        x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        x = x.permute(0, 2, 1)
        x = self.out(x)
        x = x + _input
        return x.view(b, c, h, w)

Let's train this model now:
python part_h_cifar_attn.py train --hflip

You can also change additional hyperparameters, such as the learning rate, batch size, number of epochs, number of channels, activation function (relu, silu, leakyrelu, gelu), number of residual blocks, channel multiplier, dropout, and attention resolutions (a list of downscaled resolutions to apply attention at):

python part_h_cifar_attn.py train --batch-size 128 --lr 1e-3 --epochs 120 --hflip --model-channels 32 --activation silu --num-res-blocks 2 --channel-mult 2 2 2 1 --dropout 0.1 --attention-resolutions 2

After training, we can sample a grid of images from the model with:

python part_h_cifar_attn.py test --model-channels 32 --activation silu --num-res-blocks 2 --channel-mult 2 2 2 1 --dropout 0.1 --attention-resolutions 2

Our resulting output looks like this:
For our final training run, we'll make a couple more improvements to our model:
We'll save an exponential moving average (EMA) of the model weights during training. This can reduce the noisiness of the weights, especially late in training.
ema = torch.optim.swa_utils.AveragedModel(model, multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(ema_decay))

We'll add a few more command-line args for training and inference:
parser.add_argument('--gpu', type=int, default=None, help="GPU index")
parser.add_argument('--model', type=str, default='model.pth', help="Model file")
parser.add_argument('--save-checkpoints', action='store_true', help="Save model checkpoints")

We'll now train our final model with EMA, using the same configuration as the DDPM paper:
python main.py train --model-channels 128 --num-res-blocks 2 --channel-mult 1 2 2 2 --attention-resolutions 2 \
--batch-size 128 --lr 2e-4 --epochs 2000 --dropout 0.1 --hflip --grad-clip 1 --warmup 5000 \
--gpu 0 --save-checkpoints --log-interval 1 --output-dir cifar-unconditional
After training, we can sample a grid of images from the model with:
python main.py test --model-channels 128 --activation silu --num-res-blocks 2 --channel-mult 1 2 2 2 --dropout 0.1 --attention-resolutions 2 --gpu 0 --model model-ema.pth