-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathTrainModel.py
More file actions
181 lines (153 loc) · 5.84 KB
/
TrainModel.py
File metadata and controls
181 lines (153 loc) · 5.84 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
"""
Train a model, validate it, and test it.
The code was provided as a part of the course material, but was modified.
Modifications from original:
1. Save epoch number, training loss, training accuracy, validation loss, & validation accuracy to a csv file.
2. Compute the test accuracy for the model weights with the highest validation accuracy and save it to a file (test_accuracy.dat).
Parameters
----------
model : torch.nn.Module
The neural network model to be trained.
criterion : torch.nn.Module
The loss function.
optimizer : torch.optim.Optimizer
The optimization algorithm.
scheduler : torch.optim.lr_scheduler._LRScheduler
The learning rate scheduler.
train_loader : torch.utils.data.DataLoader
DataLoader for the training dataset.
val_loader : torch.utils.data.DataLoader
DataLoader for the validation dataset.
test_loader : torch.utils.data.DataLoader
DataLoader for the test dataset.
output_dir : pathlib.Path
Directory where the output files will be saved.
num_epochs : int, optional
Number of epochs to train the model (default is 20).
Returns
-------
model : torch.nn.Module
The trained model with the best validation accuracy.
Notes
-----
- The function saves the training progress (epoch number, training loss, training accuracy, validation loss, and validation accuracy) to a CSV file.
- The function computes the test accuracy for the model weights with the highest validation accuracy and saves it to a file.
"""
import copy
import numpy as np
import pandas as pd
import torch
def train_model(
    model,
    criterion,
    optimizer,
    scheduler,
    train_loader,
    val_loader,
    test_loader,
    output_dir,
    num_epochs=20,
):
    """
    Train a model, keep the weights with the best validation accuracy, and test them.

    Parameters
    ----------
    model : torch.nn.Module
        The neural network model to be trained.
    criterion : torch.nn.Module
        The loss function. It is called as ``criterion(outputs, labels)``.
    optimizer : torch.optim.Optimizer
        The optimization algorithm.
    scheduler : torch.optim.lr_scheduler._LRScheduler
        The learning rate scheduler; stepped once per epoch, after the training pass.
    train_loader : torch.utils.data.DataLoader
        DataLoader for the training data.
    val_loader : torch.utils.data.DataLoader
        DataLoader for the validation data.
    test_loader : torch.utils.data.DataLoader
        DataLoader for the test data.
    output_dir : pathlib.Path or str
        Directory where the output files (training progress and test accuracy) are saved.
        It is created if it does not exist.
    num_epochs : int, optional
        Number of epochs to train the model (default is 20).

    Returns
    -------
    model : torch.nn.Module
        The model, moved to the selected device, loaded with the weights that reached the
        highest validation accuracy. If no epoch reaches a validation accuracy above 0.0,
        the initial weights are kept.

    Notes
    -----
    - Per-epoch progress is written to ``training_progress.csv`` in ``output_dir`` with the
      columns ``epoch`` (zero-based), ``training_loss``, ``training_accuracy``,
      ``validation_loss`` and ``validation_accuracy``.
    - The test accuracy of the best weights is written to ``test_accuracy.dat``.
    - Loss and accuracy are averaged over the number of samples actually seen, so a partial
      last batch does not skew them. An empty loader yields NaN metrics.
    """
    # Function-local import: the module itself does not import pathlib.
    from pathlib import Path

    # Create the output directory up front so a missing directory cannot
    # make the run fail only after all the training work has been done.
    output_dir = Path(output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Get correct device
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    highest_accuracy = 0.0
    best_model_weights = copy.deepcopy(model.state_dict())

    def run_epoch(loader, training):
        """Run one pass over ``loader`` and return ``(mean loss, accuracy)`` per sample."""
        loss_sum = 0.0
        correct = 0
        seen = 0
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            if training:
                # Clear gradients left over from the previous batch; without
                # this every backward() adds to the previous gradients.
                optimizer.zero_grad()

            with torch.set_grad_enabled(training):
                outputs = model(inputs)
                preds = outputs.argmax(dim=1)
                loss = criterion(outputs, labels)
                if training:
                    loss.backward()
                    optimizer.step()

            batch_samples = inputs.size(0)
            # Weight the batch loss by its size so that dividing by the sample
            # count below gives a per-sample mean (assumes the criterion
            # returns a batch mean, as the original weighting did).
            loss_sum += loss.item() * batch_samples
            correct += (preds == labels).sum().item()
            seen += batch_samples

        if seen == 0:
            return float("nan"), float("nan")
        return loss_sum / seen, correct / seen

    # Progress data
    progress_data = []
    for epoch in range(num_epochs):
        print("Epoch: {}".format(epoch + 1))

        # train phase
        model.train()
        epoch_train_loss, epoch_train_accuracy = run_epoch(train_loader, True)
        scheduler.step()
        print("Phase: Train Loss: {} Accuracy: {}".format(epoch_train_loss, epoch_train_accuracy))

        # val phase
        model.eval()
        epoch_val_loss, epoch_val_accuracy = run_epoch(val_loader, False)
        print("Phase: Validation Loss: {} Accuracy: {}".format(epoch_val_loss, epoch_val_accuracy))

        # Snapshot the weights whenever validation accuracy improves.
        if epoch_val_accuracy > highest_accuracy:
            highest_accuracy = epoch_val_accuracy
            best_model_weights = copy.deepcopy(model.state_dict())

        progress_data.append(
            [
                epoch,
                epoch_train_loss,
                epoch_train_accuracy,
                epoch_val_loss,
                epoch_val_accuracy,
            ]
        )

    # Save progress data to csv
    pd.DataFrame(
        progress_data,
        columns=[
            "epoch",
            "training_loss",
            "training_accuracy",
            "validation_loss",
            "validation_accuracy",
        ],
    ).to_csv(output_dir / "training_progress.csv", index=False)

    # Best model. Compute test accuracy.
    print("Training finished. Highest validation accuracy: {}".format(highest_accuracy))
    model.load_state_dict(best_model_weights)
    model.eval()
    _, test_accuracy = run_epoch(test_loader, False)
    with open(output_dir / "test_accuracy.dat", encoding="utf-8", mode="w") as fid:
        fid.write(str(test_accuracy))

    return model