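"""Training script for a d-vector style speaker-embedding CNN.

Trains a CNN speaker classifier with cross-entropy over speaker labels and
reports the equal error rate (EER) on a verification trial list after each
epoch. Training can be resumed from a saved checkpoint via a command-line
start epoch (see the usage note at the bottom of this file).
"""
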
import os
import sys

import torch
from torch import nn
from torch.utils.data import DataLoader
from tqdm import tqdm
import wandb  # only referenced by the commented-out experiment-tracking block below

from model import CNN
from train_dataset import TrainDataset
from test_dataset import TestDataset
from test import get_eer


def train():
    # use the second GPU if available, otherwise fall back to CPU
    device = torch.device('cuda:1' if torch.cuda.is_available() else 'cpu')
    print(device)

    # data paths
    train_data_path = '/data/train'
    test_data_path = '/data/test'
    trial_path = '/data/trials/trials.txt'

    # hyperparameters
    classes = 1211  # number of speakers in the training set
    learning_rate = 0.001
    embedding_size = 512
    n_mels = 40
    epochs = 150
    batch_size = 512
    loss_fn = nn.CrossEntropyLoss().to(device)
    # resume from a saved checkpoint when a start epoch is given on the
    # command line; otherwise build a fresh model and train from scratch
    try:
        start_epoch = int(sys.argv[1])
        model = torch.load('model_epoch' + str(start_epoch) + '.pth').to(device)
    except (IndexError, ValueError, FileNotFoundError):
        start_epoch = 0
        model = CNN(
            embedding_size=embedding_size,
            class_size=classes,
            n_mels=n_mels,
        ).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)
    # wandb setup
    # os.system('wandb login be65d6ddace6bf4e2441a82af03c144eb85bbe65')
    # wandb.init(project='dvector-original-s3v1', entity='dvector')
    # wandb.config = {
    #     "learning_rate": learning_rate,
    #     "epochs": epochs,
    #     "batch_size": batch_size,
    # }
    # wandb.define_metric("loss")
    # wandb.define_metric("eer")
    # build a speaker_id -> label dictionary: a directory name such as
    # 'id10001' yields speaker ID 1, and labels 0..len(speaker_ids)-1
    # are assigned in sorted-ID order
    speaker_ids = list(map(lambda x: int(x.split('id1')[1]), os.listdir(train_data_path)))
    speaker_ids.sort()
    labels = [i for i in range(len(speaker_ids))]
    speaker_ids_to_labels = {speaker_ids[i]: labels[i] for i in range(len(speaker_ids))}
    print(f'speakers: {len(speaker_ids)}')

    train_data = TrainDataset(train_data_path, speaker_ids_to_labels)
    train_dataloader = DataLoader(train_data, batch_size=batch_size, shuffle=True, num_workers=4)
    test_data = TestDataset(test_data_path)
    for epoch in range(1, epochs + 1):
        print('Epoch: ' + str(epoch))
        model.train()
        for (X, y) in tqdm(train_dataloader):
            optimizer.zero_grad()
            X = X.to(device)
            y = y.to(device)
            # only the model's second output (class logits) feeds the loss;
            # the first output is the speaker embedding
            _, pred = model(X)
            loss = loss_fn(pred, y)
            loss.backward()
            optimizer.step()
            # wandb.log({"loss": loss})
        # scheduler.step()

        # recompute test-utterance embeddings with the current model and
        # score the trial list to get the equal error rate
        test_data.update_embeddings(model, embedding_size, device)
        eer, threshold = get_eer(
            test_dataset=test_data,
            test_path=test_data_path,
            trial_path=trial_path,
        )
        print('Threshold: ' + str(threshold))
        print('EER: ' + str(eer))
        # wandb.log({"eer": eer})

        # checkpoint every 100 epochs; the saved name reflects the total
        # number of epochs trained so it can be passed back as sys.argv[1]
        if epoch % 100 == 0:
            checkpoint_path = 'model_epoch' + str(start_epoch + epoch) + '.pth'
            torch.save(model, checkpoint_path)

    # final checkpoint after the last epoch
    checkpoint_path = 'model_epoch' + str(start_epoch + epochs) + '.pth'
    torch.save(model, checkpoint_path)

if __name__ == '__main__':
    train()
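# Usage (assumed invocation, based on the sys.argv resume logic above):
#   python train.py        # train from scratch
#   python train.py 100    # resume from model_epoch100.pth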