-
Notifications
You must be signed in to change notification settings - Fork 1
Sourcery Starbot ⭐ refactored undo76/PyTorch-GAN #1
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -33,14 +33,13 @@ | |
|
|
||
| img_shape = (opt.channels, opt.img_size, opt.img_size) | ||
|
|
||
| cuda = True if torch.cuda.is_available() else False | ||
| cuda = bool(torch.cuda.is_available()) | ||
|
|
||
|
|
||
| def reparameterization(mu, logvar): | ||
| std = torch.exp(logvar / 2) | ||
| sampled_z = Variable(Tensor(np.random.normal(0, 1, (mu.size(0), opt.latent_dim)))) | ||
| z = sampled_z * std + mu | ||
| return z | ||
| return sampled_z * std + mu | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
|
||
|
|
||
| class Encoder(nn.Module): | ||
|
|
@@ -63,8 +62,7 @@ def forward(self, img): | |
| x = self.model(img_flat) | ||
| mu = self.mu(x) | ||
| logvar = self.logvar(x) | ||
| z = reparameterization(mu, logvar) | ||
| return z | ||
| return reparameterization(mu, logvar) | ||
|
Comment on lines
-66
to
+65
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
|
||
|
|
||
| class Decoder(nn.Module): | ||
|
|
@@ -83,8 +81,7 @@ def __init__(self): | |
|
|
||
| def forward(self, z): | ||
| img_flat = self.model(z) | ||
| img = img_flat.view(img_flat.shape[0], *img_shape) | ||
| return img | ||
| return img_flat.view(img_flat.shape[0], *img_shape) | ||
|
Comment on lines
-86
to
+84
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
|
||
|
|
||
| class Discriminator(nn.Module): | ||
|
|
@@ -101,8 +98,7 @@ def __init__(self): | |
| ) | ||
|
|
||
| def forward(self, z): | ||
| validity = self.model(z) | ||
| return validity | ||
| return self.model(z) | ||
|
Comment on lines
-104
to
+101
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
|
||
|
|
||
| # Use binary cross-entropy loss | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -31,7 +31,7 @@ | |
| opt = parser.parse_args() | ||
| print(opt) | ||
|
|
||
| cuda = True if torch.cuda.is_available() else False | ||
| cuda = bool(torch.cuda.is_available()) | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Lines
|
||
|
|
||
|
|
||
| def weights_init_normal(m): | ||
|
|
@@ -70,8 +70,7 @@ def forward(self, noise, labels): | |
| gen_input = torch.mul(self.label_emb(labels), noise) | ||
| out = self.l1(gen_input) | ||
| out = out.view(out.shape[0], 128, self.init_size, self.init_size) | ||
| img = self.conv_blocks(out) | ||
| return img | ||
| return self.conv_blocks(out) | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
|
||
|
|
||
| class Discriminator(nn.Module): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -32,7 +32,7 @@ | |
|
|
||
| img_shape = (opt.channels, opt.img_size, opt.img_size) | ||
|
|
||
| cuda = True if torch.cuda.is_available() else False | ||
| cuda = bool(torch.cuda.is_available()) | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Lines
|
||
|
|
||
|
|
||
| def weights_init_normal(m): | ||
|
|
@@ -68,8 +68,7 @@ def __init__(self): | |
| def forward(self, noise): | ||
| out = self.l1(noise) | ||
| out = out.view(out.shape[0], 128, self.init_size, self.init_size) | ||
| img = self.conv_blocks(out) | ||
| return img | ||
| return self.conv_blocks(out) | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
|
||
|
|
||
| class Discriminator(nn.Module): | ||
|
|
@@ -189,7 +188,7 @@ def forward(self, img): | |
| diff = torch.mean(gamma * d_loss_real - d_loss_fake) | ||
|
|
||
| # Update weight term for fake samples | ||
| k = k + lambda_k * diff.item() | ||
| k += lambda_k * diff.item() | ||
|
Comment on lines
-192
to
+191
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Lines
|
||
| k = min(max(k, 0), 1) # Constraint to interval [0, 1] | ||
|
|
||
| # Update convergence metric | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -34,7 +34,7 @@ | |
|
|
||
| img_shape = (opt.channels, opt.img_size, opt.img_size) | ||
|
|
||
| cuda = True if torch.cuda.is_available() else False | ||
| cuda = bool(torch.cuda.is_available()) | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Lines
|
||
|
|
||
|
|
||
| class Generator(nn.Module): | ||
|
|
@@ -78,8 +78,7 @@ def __init__(self): | |
|
|
||
| def forward(self, img): | ||
| img_flat = img.view(img.shape[0], -1) | ||
| validity = self.model(img_flat) | ||
| return validity | ||
| return self.model(img_flat) | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
|
||
|
|
||
| def boundary_seeking_loss(y_pred, y_true): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -45,7 +45,7 @@ | |
| os.makedirs("images/%s" % opt.dataset_name, exist_ok=True) | ||
| os.makedirs("saved_models/%s" % opt.dataset_name, exist_ok=True) | ||
|
|
||
| cuda = True if torch.cuda.is_available() else False | ||
| cuda = bool(torch.cuda.is_available()) | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Lines
|
||
|
|
||
| input_shape = (opt.channels, opt.img_height, opt.img_width) | ||
|
|
||
|
|
@@ -125,8 +125,7 @@ def sample_images(batches_done): | |
| def reparameterization(mu, logvar): | ||
| std = torch.exp(logvar / 2) | ||
| sampled_z = Variable(Tensor(np.random.normal(0, 1, (mu.size(0), opt.latent_dim)))) | ||
| z = sampled_z * std + mu | ||
| return z | ||
| return sampled_z * std + mu | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
|
||
|
|
||
| # ---------- | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -154,8 +154,7 @@ def discriminator_block(in_filters, out_filters, normalize=True): | |
|
|
||
| def compute_loss(self, x, gt): | ||
| """Computes the MSE between model output and scalar gt""" | ||
| loss = sum([torch.mean((out - gt) ** 2) for out in self.forward(x)]) | ||
| return loss | ||
| return sum(torch.mean((out - gt) ** 2) for out in self.forward(x)) | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
|
||
| def forward(self, x): | ||
| outputs = [] | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -36,7 +36,7 @@ | |
| opt = parser.parse_args() | ||
| print(opt) | ||
|
|
||
| cuda = True if torch.cuda.is_available() else False | ||
| cuda = bool(torch.cuda.is_available()) | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Lines
|
||
|
|
||
| input_shape = (opt.channels, opt.img_size, opt.img_size) | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -38,8 +38,7 @@ def __init__(self, in_size, out_size, dropout=0.0): | |
|
|
||
| def forward(self, x, skip_input): | ||
| x = self.model(x) | ||
| out = torch.cat((x, skip_input), 1) | ||
| return out | ||
| return torch.cat((x, skip_input), 1) | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
|
||
|
|
||
| class Generator(nn.Module): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -33,7 +33,7 @@ | |
|
|
||
| img_shape = (opt.channels, opt.img_size, opt.img_size) | ||
|
|
||
| cuda = True if torch.cuda.is_available() else False | ||
| cuda = bool(torch.cuda.is_available()) | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Lines
|
||
|
|
||
|
|
||
| class Generator(nn.Module): | ||
|
|
@@ -87,8 +87,7 @@ def __init__(self): | |
| def forward(self, img, labels): | ||
| # Concatenate label embedding and image to produce input | ||
| d_in = torch.cat((img.view(img.size(0), -1), self.label_embedding(labels)), -1) | ||
| validity = self.model(d_in) | ||
| return validity | ||
| return self.model(d_in) | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
|
||
|
|
||
| # Loss functions | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -105,13 +105,7 @@ def calc_gradient_penalty(netD, real_data, generated_data): | |
| # Weight Initializer | ||
| def initialize_weights(net): | ||
| for m in net.modules(): | ||
| if isinstance(m, nn.Conv2d): | ||
| m.weight.data.normal_(0, 0.02) | ||
| m.bias.data.zero_() | ||
| elif isinstance(m, nn.ConvTranspose2d): | ||
| m.weight.data.normal_(0, 0.02) | ||
| m.bias.data.zero_() | ||
| elif isinstance(m, nn.Linear): | ||
| if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)): | ||
| m.weight.data.normal_(0, 0.02) | ||
| m.bias.data.zero_() | ||
|
Comment on lines
-108
to
116
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
|
|
||
|
|
@@ -292,9 +286,7 @@ def __init__(self, wass_metric=False, verbose=False): | |
| print(self.model) | ||
|
|
||
| def forward(self, img): | ||
| # Get output | ||
| validity = self.model(img) | ||
| return validity | ||
| return self.model(img) | ||
|
Comment on lines
-295
to
+289
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
This removes the following comments ( why? ): |
||
|
|
||
|
|
||
|
|
||
|
|
@@ -323,7 +315,7 @@ def forward(self, img): | |
|
|
||
| x_shape = (channels, img_size, img_size) | ||
|
|
||
| cuda = True if torch.cuda.is_available() else False | ||
| cuda = bool(torch.cuda.is_available()) | ||
|
Comment on lines
-326
to
+318
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Lines
|
||
| device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') | ||
|
|
||
| # Loss function | ||
|
|
@@ -343,7 +335,7 @@ def forward(self, img): | |
| bce_loss.cuda() | ||
| xe_loss.cuda() | ||
| mse_loss.cuda() | ||
|
|
||
| Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor | ||
|
|
||
| # Configure data loader | ||
|
|
@@ -395,9 +387,11 @@ def forward(self, img): | |
|
|
||
| # Training loop | ||
| print('\nBegin training session with %i epochs...\n'%(n_epochs)) | ||
| # Set number of examples for cycle calcs | ||
| n_sqrt_samp = 5 | ||
| for epoch in range(n_epochs): | ||
| for i, (imgs, itruth_label) in enumerate(dataloader): | ||
|
|
||
| # Ensure generator/encoder are trainable | ||
| generator.train() | ||
| encoder.train() | ||
|
|
@@ -406,28 +400,28 @@ def forward(self, img): | |
| generator.zero_grad() | ||
| encoder.zero_grad() | ||
| discriminator.zero_grad() | ||
|
|
||
| # Configure input | ||
| real_imgs = Variable(imgs.type(Tensor)) | ||
|
|
||
| # --------------------------- | ||
| # Train Generator + Encoder | ||
| # --------------------------- | ||
|
|
||
| optimizer_GE.zero_grad() | ||
|
|
||
| # Sample random latent variables | ||
| zn, zc, zc_idx = sample_z(shape=imgs.shape[0], | ||
| latent_dim=latent_dim, | ||
| n_c=n_c) | ||
|
|
||
| # Generate a batch of images | ||
| gen_imgs = generator(zn, zc) | ||
|
|
||
| # Discriminator output from real and generated samples | ||
| D_gen = discriminator(gen_imgs) | ||
| D_real = discriminator(real_imgs) | ||
|
|
||
| # Step for Generator & Encoder, n_skip_iter times less than for discriminator | ||
| if (i % n_skip_iter == 0): | ||
| # Encode the generated images | ||
|
|
@@ -463,7 +457,7 @@ def forward(self, img): | |
|
|
||
| # Wasserstein GAN loss w/gradient penalty | ||
| d_loss = torch.mean(D_real) - torch.mean(D_gen) + grad_penalty | ||
|
|
||
| else: | ||
| # Vanilla GAN loss | ||
| fake = Variable(Tensor(gen_imgs.size(0), 1).fill_(0.0), requires_grad=False) | ||
|
|
@@ -484,9 +478,7 @@ def forward(self, img): | |
| generator.eval() | ||
| encoder.eval() | ||
|
|
||
| # Set number of examples for cycle calcs | ||
| n_sqrt_samp = 5 | ||
| n_samp = n_sqrt_samp * n_sqrt_samp | ||
| n_samp = n_sqrt_samp**2 | ||
|
|
||
|
|
||
| ## Cycle through test real -> enc -> gen | ||
|
|
@@ -499,7 +491,7 @@ def forward(self, img): | |
| img_mse_loss = mse_loss(t_imgs, teg_imgs) | ||
| # Save img reco cycle loss | ||
| c_i.append(img_mse_loss.item()) | ||
|
|
||
|
|
||
| ## Cycle through randomly sampled encoding -> generator -> encoder | ||
| zn_samp, zc_samp, zc_samp_idx = sample_z(shape=n_samp, | ||
|
|
@@ -518,7 +510,7 @@ def forward(self, img): | |
| # Save latent space cycle losses | ||
| c_zn.append(lat_mse_loss.item()) | ||
| c_zc.append(lat_xe_loss.item()) | ||
|
|
||
| # Save cycled and generated examples! | ||
| r_imgs, i_label = real_imgs.data[:n_samp], itruth_label[:n_samp] | ||
| e_zn, e_zc, e_zc_logits = encoder(r_imgs) | ||
|
|
@@ -529,7 +521,7 @@ def forward(self, img): | |
| save_image(gen_imgs_samp.data[:n_samp], | ||
| 'images/gen_%06i.png' %(epoch), | ||
| nrow=n_sqrt_samp, normalize=True) | ||
|
|
||
| ## Generate samples for specified classes | ||
| stack_imgs = [] | ||
| for idx in range(n_c): | ||
|
|
@@ -542,7 +534,7 @@ def forward(self, img): | |
| # Generate sample instances | ||
| gen_imgs_samp = generator(zn_samp, zc_samp) | ||
|
|
||
| if (len(stack_imgs) == 0): | ||
| if not stack_imgs: | ||
| stack_imgs = gen_imgs_samp | ||
| else: | ||
| stack_imgs = torch.cat((stack_imgs, gen_imgs_samp), 0) | ||
|
|
@@ -551,15 +543,15 @@ def forward(self, img): | |
| save_image(stack_imgs, | ||
| 'images/gen_classes_%06i.png' %(epoch), | ||
| nrow=n_c, normalize=True) | ||
|
|
||
|
|
||
| print ("[Epoch %d/%d] \n"\ | ||
| "\tModel Losses: [D: %f] [GE: %f]" % (epoch, | ||
| n_epochs, | ||
| d_loss.item(), | ||
| ge_loss.item()) | ||
| ) | ||
|
|
||
| print("\tCycle Losses: [x: %f] [z_n: %f] [z_c: %f]"%(img_mse_loss.item(), | ||
| lat_mse_loss.item(), | ||
| lat_xe_loss.item()) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -36,7 +36,7 @@ | |
|
|
||
| img_shape = (opt.channels, opt.img_size, opt.img_size) | ||
|
|
||
| cuda = True if torch.cuda.is_available() else False | ||
| cuda = bool(torch.cuda.is_available()) | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Lines
|
||
|
|
||
|
|
||
| def weights_init_normal(m): | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -105,9 +105,7 @@ def download(self): | |
| os.makedirs(os.path.join(self.root, self.raw_folder)) | ||
| os.makedirs(os.path.join(self.root, self.processed_folder)) | ||
| except OSError as e: | ||
| if e.errno == errno.EEXIST: | ||
| pass | ||
| else: | ||
| if e.errno != errno.EEXIST: | ||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Function
|
||
| raise | ||
|
|
||
| # download pkl files | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,6 +8,7 @@ | |
| 4. Run the sript using command 'python3 context_encoder.py' | ||
| """ | ||
|
|
||
|
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Lines
|
||
|
|
||
| import argparse | ||
| import os | ||
| import numpy as np | ||
|
|
@@ -46,7 +47,7 @@ | |
| opt = parser.parse_args() | ||
| print(opt) | ||
|
|
||
| cuda = True if torch.cuda.is_available() else False | ||
| cuda = bool(torch.cuda.is_available()) | ||
|
|
||
| # Calculate output of image discriminator (PatchGAN) | ||
| patch_h, patch_w = int(opt.mask_size / 2 ** 3), int(opt.mask_size / 2 ** 3) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Lines
36-36refactored with the following changes:boolean-if-exp-identity)