-
Notifications
You must be signed in to change notification settings - Fork 43
Expand file tree
/
Copy pathtrain.py
More file actions
318 lines (245 loc) · 12.9 KB
/
train.py
File metadata and controls
318 lines (245 loc) · 12.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
import os
import torch
import random
import shutil
import argparse
import numpy as np
from datetime import datetime
from pathlib import Path
from omegaconf import OmegaConf
from tqdm import tqdm
from glob import glob
from pesq import pesq
from joblib import Parallel, delayed
import soundfile as sf
import torch.distributed as dist
from torch.utils.tensorboard import SummaryWriter
from distributed_utils import reduce_value
from models.gtcrn_end2end import GTCRN as Model
from loss_factory import HybridLoss as Loss
from dataloader_dns3 import DNS3Dataset as Dataset
from scheduler import LinearWarmupCosineAnnealingLR as WarmupLR
# Seed every random number generator the script touches so that runs are
# repeatable. PYTHONHASHSEED is exported as well; note that it is only read at
# interpreter start-up, so it matters for processes launched after this point.
seed = 43
os.environ['PYTHONHASHSEED'] = str(seed)
for _seed_fn in (
    random.seed,
    np.random.seed,
    torch.manual_seed,
    torch.cuda.manual_seed,
    torch.cuda.manual_seed_all,
):
    _seed_fn(seed)
del _seed_fn
# cuDNN deterministic mode is deliberately left disabled:
# torch.backends.cudnn.deterministic = True
def run(rank, config, args):
    """Set up one training process and run the full training loop in it.

    Builds the train/validation datasets and loaders, the model, the Adam
    optimiser, the warm-up scheduler and the loss, hands them to ``Trainer``
    and calls ``Trainer.train``. When ``args.world_size > 1`` an NCCL process
    group is created first and torn down after training.

    Args:
        rank: Index of this process. It is also used as the CUDA device index.
        config: Mapping loaded from the training config. The sections read
            here are ``train_dataset``, ``train_dataloader``,
            ``validation_dataset``, ``validation_dataloader``,
            ``network_config``, ``optimizer``, ``scheduler.kwargs`` and
            ``loss``; the whole mapping is forwarded to ``Trainer``.
        args: Parsed command-line namespace. ``world_size`` must already be
            set; ``rank`` and ``device`` are written onto it here.
    """
    distributed = args.world_size > 1

    if distributed:
        # Respect a rendezvous address/port supplied by the environment so
        # that several jobs can share one host; fall back to the defaults.
        os.environ.setdefault('MASTER_ADDR', 'localhost')
        os.environ.setdefault('MASTER_PORT', '12354')
        dist.init_process_group("nccl", rank=rank, world_size=args.world_size)
        torch.cuda.set_device(rank)
        dist.barrier()

    args.rank = rank
    # Trainer already special-cases a CPU device, so fall back to it when no
    # CUDA device is present instead of failing on the first ``.to(device)``.
    if torch.cuda.is_available():
        args.device = torch.device(rank)
    else:
        args.device = torch.device('cpu')

    collate_fn = Dataset.collate_fn if hasattr(Dataset, "collate_fn") else None

    # The configured batch size is used as-is by every process; it is not
    # divided by world_size.
    train_dataset = Dataset(**config['train_dataset'])
    train_sampler = (
        torch.utils.data.distributed.DistributedSampler(train_dataset)
        if distributed else None
    )
    # DataLoader forbids shuffle=True together with a sampler, so shuffling is
    # only requested from the loader itself in the single-process case.
    shuffle = not distributed
    train_dataloader = torch.utils.data.DataLoader(dataset=train_dataset,
                                                   sampler=train_sampler,
                                                   **config['train_dataloader'],
                                                   shuffle=shuffle,
                                                   collate_fn=collate_fn)

    validation_dataset = Dataset(**config['validation_dataset'])
    validation_sampler = (
        torch.utils.data.distributed.DistributedSampler(validation_dataset)
        if distributed else None
    )
    validation_dataloader = torch.utils.data.DataLoader(dataset=validation_dataset,
                                                        sampler=validation_sampler,
                                                        **config['validation_dataloader'],
                                                        shuffle=False,
                                                        collate_fn=collate_fn)

    model = Model(**config['network_config']).to(args.device)
    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])

    optimizer = torch.optim.Adam(params=model.parameters(), **config['optimizer'])

    # Alternatives that were tried here and can be swapped back in:
    # scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, **config['scheduler']['kwargs'])
    # scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, **config['scheduler']['kwargs'])
    # scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, **config['scheduler']['kwargs'])
    scheduler = WarmupLR(optimizer, **config['scheduler']['kwargs'])

    loss_func = Loss(**config['loss']).to(args.device)

    trainer = Trainer(config=config, model=model, optimizer=optimizer, scheduler=scheduler,
                      loss_func=loss_func, train_dataloader=train_dataloader,
                      validation_dataloader=validation_dataloader,
                      train_sampler=train_sampler, args=args)
    trainer.train()

    if distributed:
        dist.destroy_process_group()
class Trainer:
    """Train/validate/checkpoint loop around a model, optimiser and scheduler.

    The constructor prepares the experiment directory (logs, checkpoints,
    validation samples and a copy of the source code), and on rank 0 also
    writes the config and opens a TensorBoard writer. ``train`` then
    alternates one training and one validation epoch, steps the scheduler,
    saves periodic checkpoints and finally writes the best checkpoint seen.

    Args:
        config: Training configuration mapping. Keys read by this class:
            ``trainer`` (``epochs``, ``save_checkpoint_interval``,
            ``clip_grad_norm_value``, ``resume``, ``exp_path``,
            ``resume_datetime``), ``scheduler`` (``update_interval``,
            ``use_plateau``), ``samplerate``; ``DDP.world_size`` is written.
        model: Network to train (wrapped in DDP when ``world_size > 1``).
        optimizer: Optimiser updating ``model``.
        scheduler: Learning-rate scheduler attached to ``optimizer``.
        loss_func: Callable ``loss_func(enhanced, clean)`` returning the loss.
        train_dataloader: Iterable of ``(noisy, clean)`` training batches.
        validation_dataloader: Iterable of ``(noisy, clean)`` validation batches.
        train_sampler: Distributed sampler of the training loader, or ``None``.
        args: Namespace providing ``rank``, ``device`` and ``world_size``.
    """

    def __init__(self, config, model, optimizer, scheduler, loss_func,
                 train_dataloader, validation_dataloader, train_sampler, args):
        self.config = config
        self.model = model
        self.optimizer = optimizer
        self.scheduler = scheduler
        self.loss_func = loss_func

        self.train_dataloader = train_dataloader
        self.validation_dataloader = validation_dataloader
        self.train_sampler = train_sampler

        self.rank = args.rank
        self.device = args.device
        self.world_size = args.world_size

        # training config
        config['DDP']['world_size'] = args.world_size
        self.trainer_config = config['trainer']
        self.epochs = self.trainer_config['epochs']
        self.save_checkpoint_interval = self.trainer_config['save_checkpoint_interval']
        self.clip_grad_norm_value = self.trainer_config['clip_grad_norm_value']
        self.resume = self.trainer_config['resume']

        # A fresh run gets a timestamped directory; a resumed run re-uses the
        # directory named by the configured timestamp.
        if not self.resume:
            self.exp_path = self.trainer_config['exp_path'] + '_' + datetime.now().strftime("%Y-%m-%d-%Hh%Mm")
        else:
            self.exp_path = self.trainer_config['exp_path'] + '_' + self.trainer_config['resume_datetime']

        self.log_path = os.path.join(self.exp_path, 'logs')
        self.checkpoint_path = os.path.join(self.exp_path, 'checkpoints')
        self.sample_path = os.path.join(self.exp_path, 'val_samples')
        self.code_path = os.path.join(self.exp_path, 'codes')

        for directory in (self.log_path, self.checkpoint_path, self.sample_path, self.code_path):
            os.makedirs(directory, exist_ok=True)

        # save the config and codes
        if self.rank == 0:
            data = OmegaConf.create(config)
            OmegaConf.save(data, os.path.join(self.exp_path, 'config.yaml'))

            shutil.copy2(__file__, self.exp_path)
            for file in Path(__file__).parent.iterdir():
                if file.is_file():
                    shutil.copy2(file, self.code_path)
            shutil.copytree(Path(__file__).parent / 'models', Path(self.code_path) / 'models', dirs_exist_ok=True)

            # Only rank 0 owns a writer; every use below is guarded by rank == 0.
            self.writer = SummaryWriter(self.log_path)

        self.start_epoch = 1
        # NOTE: the best score is not stored in checkpoints, so tracking
        # restarts from 0 after a resume.
        self.best_score = 0
        # Snapshot of the best checkpoint so far; None until one is recorded.
        self.state_dict_best = None

        if self.resume:
            self._resume_checkpoint()

    def _set_train_mode(self):
        """Put the model into training mode."""
        self.model.train()

    def _set_eval_mode(self):
        """Put the model into evaluation mode."""
        self.model.eval()

    def _save_checkpoint(self, epoch, score):
        """Write ``model_<epoch>.tar`` and remember it if it is the best so far.

        Args:
            epoch: Epoch number stored in the checkpoint and used in the file name.
            score: Validation score of this epoch; higher is better.
        """
        import copy  # local import: only needed for the best-state snapshot

        model_dict = self.model.module.state_dict() if self.world_size > 1 else self.model.state_dict()
        state_dict = {'epoch': epoch,
                      'optimizer': self.optimizer.state_dict(),
                      'scheduler': self.scheduler.state_dict(),
                      'model': model_dict}

        torch.save(state_dict, os.path.join(self.checkpoint_path, f'model_{str(epoch).zfill(3)}.tar'))

        if score > self.best_score:
            # state_dict() hands out tensors that share storage with the live
            # parameters and optimiser state. A shallow copy would therefore
            # keep changing as training continues, and the "best" checkpoint
            # written at the end would contain the final weights instead.
            self.state_dict_best = copy.deepcopy(state_dict)
            self.best_score = score

    def _resume_checkpoint(self):
        """Restore model, optimiser, scheduler and start epoch from the latest checkpoint.

        The latest checkpoint is the last ``model_*.tar`` in lexicographic
        order, which matches epoch order for the zero-padded three-digit names
        written by ``_save_checkpoint``.

        Raises:
            FileNotFoundError: If the checkpoint directory holds no ``model_*.tar``.
        """
        candidates = sorted(glob(os.path.join(self.checkpoint_path, 'model_*.tar')))
        if not candidates:
            raise FileNotFoundError(
                'No checkpoint matching model_*.tar found in {}'.format(self.checkpoint_path))
        latest_checkpoints = candidates[-1]

        map_location = self.device
        checkpoint = torch.load(latest_checkpoints, map_location=map_location)

        self.start_epoch = checkpoint['epoch'] + 1
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.scheduler.load_state_dict(checkpoint['scheduler'])
        if self.world_size > 1:
            self.model.module.load_state_dict(checkpoint['model'])
        else:
            self.model.load_state_dict(checkpoint['model'])

    def _train_epoch(self, epoch):
        """Run one pass over the training loader and log lr / mean loss on rank 0.

        Args:
            epoch: Current epoch number (used for the progress bar and logging).
        """
        total_loss = 0
        step = 0  # stays 0 if the loader yields nothing

        # Let the dataset draw a new subset for this epoch when it supports it.
        if hasattr(self.train_dataloader.dataset, "sample_data_per_epoch"):
            self.train_dataloader.dataset.sample_data_per_epoch()

        self.train_bar = tqdm(self.train_dataloader, ncols=110)

        for step, (noisy, clean) in enumerate(self.train_bar, 1):
            noisy = noisy.to(self.device)
            clean = clean.to(self.device)

            enhanced = self.model(noisy)
            loss = self.loss_func(enhanced, clean)

            # The cross-rank value is only needed for reporting. reduce_value
            # is defined elsewhere, so it is given a detached copy: whatever
            # it does to its argument cannot affect the tensor that
            # backward() runs on below.
            if self.world_size > 1:
                loss_value = reduce_value(loss.detach().clone()).item()
            else:
                loss_value = loss.item()
            total_loss += loss_value

            self.train_bar.desc = '   train[{}/{}][{}]'.format(
                epoch, self.epochs + self.start_epoch-1, datetime.now().strftime("%Y-%m-%d-%H:%M"))
            self.train_bar.postfix = 'train_loss={:.3f}'.format(total_loss / step)

            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.clip_grad_norm_value)
            self.optimizer.step()

            if self.config['scheduler']['update_interval'] == 'step':
                self.scheduler.step()

        if self.world_size > 1 and (self.device != torch.device("cpu")):
            torch.cuda.synchronize(self.device)

        if self.rank == 0:
            num_steps = max(step, 1)  # avoid dividing by zero on an empty loader
            self.writer.add_scalars('lr', {'lr': self.optimizer.param_groups[0]['lr']}, epoch)
            self.writer.add_scalars('train_loss', {'train_loss': total_loss / num_steps}, epoch)

    @torch.inference_mode()
    def _validation_epoch(self, epoch):
        """Run one pass over the validation loader.

        For every batch the loss and the wide-band PESQ (computed at a fixed
        16000 Hz via ``pesq(16000, clean, enhanced, 'wb')``) are accumulated.
        On rank 0, at epoch 1 and every 10th epoch, the first item of each of
        the first three batches is written to ``val_samples`` as wav files.

        Args:
            epoch: Current epoch number.

        Returns:
            Tuple ``(mean_loss, mean_pesq)`` as Python floats, averaged over
            the validation steps.
        """
        total_loss = 0
        total_pesq_score = 0
        step = 0  # stays 0 if the loader yields nothing

        self.validation_bar = tqdm(self.validation_dataloader, ncols=123)
        for step, (noisy, clean) in enumerate(self.validation_bar, 1):
            noisy = noisy.to(self.device)
            clean = clean.to(self.device)

            enhanced = self.model(noisy)
            loss = self.loss_func(enhanced, clean)
            if self.world_size > 1:
                loss = reduce_value(loss)
            total_loss += loss.item()

            clean = clean.cpu().numpy()
            enhanced = enhanced.detach().cpu().numpy()

            pesq_score_batch = Parallel(n_jobs=-1)(
                delayed(pesq)(16000, c, e, 'wb') for c, e in zip(clean, enhanced))
            pesq_score = torch.tensor(pesq_score_batch, device=self.device).mean()
            if self.world_size > 1:
                pesq_score = reduce_value(pesq_score)
            # Accumulate a plain float so the returned score (compared with
            # best_score and handed to the scheduler) is not a device tensor.
            total_pesq_score += pesq_score.item()

            if self.rank == 0 and (epoch == 1 or epoch % 10 == 0) and step <= 3:
                noisy_path = os.path.join(self.sample_path, 'sample_{}_noisy.wav'.format(step))
                clean_path = os.path.join(self.sample_path, 'sample_{}_clean.wav'.format(step))
                enhanced_path = os.path.join(
                    self.sample_path, 'sample_{}_enh_epoch{}.wav'.format(step, str(epoch).zfill(3)))

                # Noisy/clean references do not depend on the epoch: write them once.
                if not os.path.exists(noisy_path):
                    noisy = noisy.cpu().numpy()
                    sf.write(noisy_path, noisy[0], samplerate=self.config['samplerate'])
                    sf.write(clean_path, clean[0], samplerate=self.config['samplerate'])

                sf.write(enhanced_path, enhanced[0], samplerate=self.config['samplerate'])

            self.validation_bar.desc = 'validate[{}/{}][{}]'.format(
                epoch, self.epochs + self.start_epoch-1, datetime.now().strftime("%Y-%m-%d-%H:%M"))
            self.validation_bar.postfix = 'valid_loss={:.3f}, pesq={:.4f}'.format(
                total_loss / step, total_pesq_score / step)

        if (self.world_size > 1) and (self.device != torch.device("cpu")):
            torch.cuda.synchronize(self.device)

        num_steps = max(step, 1)  # avoid dividing by zero on an empty loader
        mean_loss = total_loss / num_steps
        mean_pesq = total_pesq_score / num_steps

        if self.rank == 0:
            self.writer.add_scalars(
                'val_loss', {'val_loss': mean_loss,
                             'pesq': mean_pesq}, epoch)

        return mean_loss, mean_pesq

    def train(self):
        """Train for ``epochs`` epochs starting at ``start_epoch``.

        Each epoch: reseed the distributed sampler (if any), train, validate,
        step an epoch-level scheduler (with the PESQ score when
        ``use_plateau`` is set), and on rank 0 save a checkpoint every
        ``save_checkpoint_interval`` epochs. After the last epoch rank 0
        writes the best recorded checkpoint as ``best_model_<epoch>.tar``.
        """
        # The checkpoint, if resuming, was already restored in __init__;
        # loading it a second time here would only repeat the same work.
        for epoch in range(self.start_epoch, self.epochs + self.start_epoch):
            if self.train_sampler is not None:
                self.train_sampler.set_epoch(epoch)

            self._set_train_mode()
            self._train_epoch(epoch)

            self._set_eval_mode()
            _, score = self._validation_epoch(epoch)

            if self.config['scheduler']['update_interval'] == 'epoch':
                if self.config['scheduler']['use_plateau']:
                    self.scheduler.step(score)
                else:
                    self.scheduler.step()

            if (self.rank == 0) and (epoch % self.save_checkpoint_interval == 0):
                self._save_checkpoint(epoch, score)

        if self.rank == 0:
            # No best state exists when no checkpoint epoch was reached or no
            # score ever exceeded the initial best_score.
            if self.state_dict_best is not None:
                torch.save(self.state_dict_best,
                           os.path.join(self.checkpoint_path,
                                        'best_model_{}.tar'.format(str(self.state_dict_best['epoch']).zfill(3))))
            else:
                print('No best checkpoint was recorded; skipping best_model export.')

            print('------------Training for {} epochs is done!------------'.format(self.epochs))
if __name__ == '__main__':
    # Command-line entry point: pick the config file and the visible GPUs,
    # then start one training process per listed device.
    parser = argparse.ArgumentParser()
    parser.add_argument('-C', '--config', default='configs/cfg_train.yaml')
    parser.add_argument('-D', '--device', default='0', help='The index of the available devices, e.g. 0,1,2,3')
    args = parser.parse_args()

    # Restrict CUDA to the requested devices; inside each process they are
    # addressed by rank (0..world_size-1).
    os.environ["CUDA_VISIBLE_DEVICES"] = args.device
    # One process per comma-separated device index.
    args.world_size = args.device.count(',') + 1

    config = OmegaConf.load(args.config)

    if args.world_size <= 1:
        # Single device: run in the current process as rank 0.
        run(0, config, args)
    else:
        # spawn() passes the process index as the first argument of run().
        torch.multiprocessing.spawn(
            run, args=(config, args,), nprocs=args.world_size, join=True)