# ReinforceTrainer.py
import datetime
import math
import os
import time

import torch

import lib
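

# Actor-critic REINFORCE trainer: the actor samples translations, a
# sentence-level reward function scores them, and the critic regresses a
# per-token baseline toward those rewards. The actor is then updated with a
# policy-gradient objective weighted by (reward - baseline).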
class ReinforceTrainer(object):

    def __init__(self, actor, critic, train_data, eval_data, metrics, dicts,
                 optim, critic_optim, opt):
        self.actor = actor
        self.critic = critic

        self.train_data = train_data
        self.eval_data = eval_data
        self.evaluator = lib.Evaluator(actor, metrics, dicts, opt)

        self.actor_loss_func = metrics["nmt_loss"]
        self.critic_loss_func = metrics["critic_loss"]
        self.sent_reward_func = metrics["sent_reward"]

        self.dicts = dicts
        self.optim = optim
        self.critic_optim = critic_optim
        self.max_length = opt.max_predict_length
        self.pert_func = opt.pert_func
        self.opt = opt

        print("")
        print(actor)
        print("")
        print(critic)

    def train(self, start_epoch, end_epoch, pretrain_critic, start_time=None):
        if start_time is None:
            self.start_time = time.time()
        else:
            self.start_time = start_time

        self.optim.last_loss = self.critic_optim.last_loss = None
        self.optim.set_lr(self.opt.reinforce_lr)
        # Use a large learning rate for the critic during pre-training.
        if pretrain_critic:
            self.critic_optim.set_lr(1e-3)
        else:
            self.critic_optim.set_lr(self.opt.reinforce_lr)

        for epoch in range(start_epoch, end_epoch + 1):
            print("")
            print("* REINFORCE epoch *")
            print("Actor optim lr: %g; Critic optim lr: %g" %
                  (self.optim.lr, self.critic_optim.lr))
            if pretrain_critic:
                print("Pretrain critic...")
            no_update = self.opt.no_update and (not pretrain_critic) and \
                        (epoch == start_epoch)
            if no_update:
                print("No update...")

            train_reward, critic_loss = self.train_epoch(epoch, pretrain_critic, no_update)
            print("Train sentence reward: %.2f" % (train_reward * 100))
            print("Critic loss: %g" % critic_loss)

            valid_loss, valid_sent_reward, valid_corpus_reward = self.evaluator.eval(self.eval_data)
            valid_ppl = math.exp(min(valid_loss, 100))
            print("Validation perplexity: %.2f" % valid_ppl)
            print("Validation sentence reward: %.2f" % (valid_sent_reward * 100))
            print("Validation corpus reward: %.2f" % (valid_corpus_reward * 100))

            if no_update:
                break

            self.optim.updateLearningRate(-valid_sent_reward, epoch)
            # Actor and critic use the same lr when jointly trained.
            # TODO: would a smaller lr for the critic be better?
            if not pretrain_critic:
                self.critic_optim.set_lr(self.optim.lr)

            checkpoint = {
                "model": self.actor,
                "critic": self.critic,
                "dicts": self.dicts,
                "opt": self.opt,
                "epoch": epoch,
                "optim": self.optim,
                "critic_optim": self.critic_optim,
            }
            model_name = os.path.join(self.opt.save_dir, "model_%d" % epoch)
            if pretrain_critic:
                model_name += "_pretrain"
            else:
                model_name += "_reinforce"
            model_name += ".pt"
            torch.save(checkpoint, model_name)
            print("Saving model as %s" % model_name)

    def train_epoch(self, epoch, pretrain_critic, no_update):
        self.actor.train()

        total_reward, report_reward = 0, 0
        total_critic_loss, report_critic_loss = 0, 0
        total_sents, report_sents = 0, 0
        total_words, report_words = 0, 0
        last_time = time.time()
        for i in range(len(self.train_data)):
            batch = self.train_data[i]
            sources = batch[0]
            targets = batch[1]
            batch_size = targets.size(1)

            self.actor.zero_grad()
            self.critic.zero_grad()

            # Sample translations
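            # Mask out PAD positions in the source so the decoder's
            # attention never attends to padding.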
            attention_mask = sources[0].data.eq(lib.Constants.PAD).t()
            self.actor.decoder.attn.applyMask(attention_mask)
            samples, outputs = self.actor.sample(batch, self.max_length)

            # Calculate rewards
            rewards, samples = self.sent_reward_func(
                samples.t().tolist(), targets.data.t().tolist())
            reward = sum(rewards)

            # Perturb rewards (if specified).
            if self.pert_func is not None:
                rewards = self.pert_func(rewards)

            samples = torch.LongTensor(samples).t().contiguous()
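            # Replicate each sentence-level reward across all time steps so
            # every token position of a sample carries its sentence's reward.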
            rewards = torch.FloatTensor([rewards] * samples.size(0)).contiguous()
            if self.opt.cuda:
                samples = samples.to(self.opt.device)
                rewards = rewards.to(self.opt.device)
            samples.requires_grad_(False)
            rewards.requires_grad_(False)

            # Update critic.
            critic_weights = samples.ne(lib.Constants.PAD).float()
            num_words = critic_weights.data.sum()
            if not no_update:
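                # The critic predicts a per-token baseline and is trained by
                # regressing its predictions toward the observed rewards,
                # with PAD positions masked out by critic_weights.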
                baselines = self.critic((sources, samples), eval=False, regression=True)
                critic_loss = self.critic.backward(
                    baselines, rewards, critic_weights, num_words,
                    self.critic_loss_func, regression=True)
                self.critic_optim.step()
            else:
                critic_loss = 0

            # Update actor
            if not pretrain_critic and not no_update:
                # Subtract the baseline from the reward (the advantage).
                norm_rewards = (rewards - baselines).detach()
                actor_weights = norm_rewards * critic_weights
                # TODO: PyTorch's reinforce() could be used here, but that
                # function is a black box. Instead, specify an objective whose
                # gradient matches the policy gradient's objective; it looks
                # much like a weighted log-likelihood.
                actor_loss = self.actor.backward(
                    outputs, samples, actor_weights, 1, self.actor_loss_func)
                self.optim.step()

            # Gather stats.
            total_reward += reward
            report_reward += reward
            total_sents += batch_size
            report_sents += batch_size
            total_critic_loss += critic_loss
            report_critic_loss += critic_loss
            total_words += num_words
            report_words += num_words
            if i % self.opt.log_interval == 0 and i > 0:
                print("Epoch %3d, %6d/%d batches; actor reward: %.4f; "
                      "critic loss: %f; %5.0f tokens/s; %s elapsed" %
                      (epoch, i, len(self.train_data),
                       (report_reward / report_sents) * 100,
                       report_critic_loss / report_words,
                       report_words / (time.time() - last_time),
                       str(datetime.timedelta(seconds=int(time.time() - self.start_time)))))
                report_reward = report_sents = report_critic_loss = report_words = 0
                last_time = time.time()

        return total_reward / total_sents, total_critic_loss / total_words
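

# Example usage (a minimal sketch; how actor, critic, optimizers, metrics,
# dicts, and opt are constructed is repo-specific and assumed here, and
# `pretrain_epochs` / `end_epoch` are hypothetical variables):
#
#   trainer = ReinforceTrainer(actor, critic, train_data, eval_data,
#                              metrics, dicts, optim, critic_optim, opt)
#   # Phase 1: pre-train the critic (the actor is sampled from, not updated).
#   trainer.train(1, pretrain_epochs, pretrain_critic=True)
#   # Phase 2: joint actor-critic REINFORCE training.
#   trainer.train(pretrain_epochs + 1, end_epoch, pretrain_critic=False)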