Skip to content

Commit 0af6faf

Browse files
committed
fix for this code: "imgs = encoder(self.fake_imgs[0])", need to load in all imgs with a loop #5
1 parent 3f510d6 commit 0af6faf

1 file changed

Lines changed: 2 additions & 2 deletions

File tree

code/trainer.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -473,8 +473,8 @@ def train(self):
473473
encoder.fine_tune(fine_tune_encoder)
474474
encoder_optimizer = torch.optim.Adam(params=filter(lambda p: p.requires_grad, encoder.parameters()),
475475
lr=1e-4) if fine_tune_encoder else None
476-
imgs = encoder(self.fake_imgs[i].cuda())
477-
scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs, caps, caplens)
476+
imgs = encoder(self.fake_imgs[i]).cuda()
477+
scores, caps_sorted, decode_lengths, alphas, sort_ind = decoder(imgs.cuda(), caps.cuda(), caplens.cuda()).cuda()
478478
targets = caps_sorted[:, 1:]
479479
scores, _ = pack_padded_sequence(scores, decode_lengths, batch_first=True).cuda()
480480
targets, _ = pack_padded_sequence(targets, decode_lengths, batch_first=True).cuda()

0 commit comments

Comments
 (0)