-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathHelpers.py
More file actions
25 lines (24 loc) · 830 Bytes
/
Helpers.py
File metadata and controls
25 lines (24 loc) · 830 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
from IPython.display import clear_output
from matplotlib import pyplot as plt
import collections
import numpy as np
import pandas as pd
from PIL import Image
import torch
import torchvision
from sklearn.metrics import accuracy_score
def live_plot(loss, train_acc, valid_acc=None, figsize=(7,5), title=''):
    """Redraw a training-progress figure with loss and accuracy curves.

    Clears the current notebook cell output, then draws a fresh figure with
    the loss on the left y-axis and the accuracy curve(s) on a second
    y-axis on the right, both sharing the same x-axis.

    Args:
        loss: Values passed to ``plot`` as the 'Training Loss' curve (red).
        train_acc: Values passed to ``plot`` as the 'Training Accuracy'
            curve (green).
        valid_acc: Optional values for the 'Validation Accuracy' curve
            (blue); skipped when ``None``.
        figsize: Figure size forwarded to ``plt.subplots``.
        title: Title text for the figure.

    Returns:
        None. The figure is displayed via ``plt.show()``.
    """
    # wait=True defers clearing until new output is available, which avoids
    # flicker when this is called repeatedly from a training loop.
    clear_output(wait=True)

    fig, ax1 = plt.subplots(figsize=figsize)

    # Left axis: loss.
    ax1.plot(loss, label='Training Loss', color='red')
    ax1.legend(loc='lower left')
    ax1.set_ylabel('Cross Entropy Loss')
    # The x-label must be set on ax1: the twin axes created below has its
    # x-axis hidden, so a label set there would never be rendered.
    ax1.set_xlabel('Epoch')

    # Right axis: accuracy curve(s), sharing the x-axis with the loss.
    ax2 = ax1.twinx()
    ax2.plot(train_acc, label='Training Accuracy', color='green')
    if valid_acc is not None:
        ax2.plot(valid_acc, label='Validation Accuracy', color='blue')
    ax2.legend(loc='lower right')
    ax2.set_ylabel('Accuracy (%)')

    plt.title(title)
    plt.show()