-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest.py
More file actions
74 lines (59 loc) · 2.42 KB
/
test.py
File metadata and controls
74 lines (59 loc) · 2.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import torch
import os
# import from other python files
import config
from io_utils import parse_args, get_best_file , get_assigned_file
from model import IDC_Grading_Model, test
from dataset import FBCG
from get_temp import get_temp
def _resolve_template_path(temp_root, temp):
    """Return the path of the stain-normalisation template image named ``temp``.

    Args:
        temp_root: directory containing the template PNGs (``config.TEMP_DIR``).
        temp: template name; must be one of ``Temp1`` .. ``Temp5``.

    Returns:
        ``'<temp_root>/<temp>.png'``.

    Raises:
        ValueError: if ``temp`` is not a known template name.
    """
    if temp not in ('Temp1', 'Temp2', 'Temp3', 'Temp4', 'Temp5'):
        # ValueError subclasses Exception, so any broad handler still catches it.
        raise ValueError(f'wrong template image: {temp!r}')
    return f'{temp_root}/{temp}.png'


def _build_checkpoint_dir(save_dir, feature_extractor, train_aug, sn, temp):
    """Return the checkpoint directory for one experiment configuration.

    The name is ``<save_dir>/checkpoints/<feature_extractor>``, followed by
    ``_aug`` when augmentation was enabled and ``_<sn>_<temp>`` when a stain
    normalisation method other than ``'none'`` is selected.
    """
    checkpoint_dir = f'{save_dir}/checkpoints/{feature_extractor}'
    if train_aug:
        checkpoint_dir += '_aug'
    if sn != 'none':
        # Suffix the path that is actually used to locate the checkpoint
        # (appending to params.checkpoint_dir left the loaded path unchanged).
        # NOTE(review): this must mirror the naming used when checkpoints were
        # saved; assumes both suffixes apply only when SN is on — confirm
        # against the training script.
        checkpoint_dir += f'_{sn}_{temp}'
    return checkpoint_dir


def main():
    """Evaluate a trained IDC grading model on the FBCG test set.

    Parses the ``test`` arguments and builds the test dataloader, optionally
    stain-normalised against a template image. Loads the checkpoint chosen by
    ``get_assigned_file`` (when ``params.save_iter != -1``) or by
    ``get_best_file``. Prints the test accuracy and appends it to
    ``<config.RECORD_DIR>/results.txt``.

    Raises:
        ValueError: if stain normalisation is requested with an unknown template.
        FileNotFoundError: if no checkpoint file is found for the configuration.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Results are appended to a shared text file in this directory.
    result_dir = config.RECORD_DIR
    os.makedirs(result_dir, exist_ok=True)

    params = parse_args('test')
    feature_extractor = params.feature_extractor
    train_aug = params.train_aug
    unzip = params.unzip
    batch_size = params.batch_size

    # Resolve the template image used by the stain-normalisation methods.
    if params.sn in ['reinhard', 'macenko', 'vahadane']:
        # Fetch the template images if they are not present yet.
        get_temp(config.DATA_PATH)
        temp_dir = _resolve_template_path(config.TEMP_DIR, params.temp)
    else:
        temp_dir = None
        # Recorded verbatim in the results line below.
        params.temp = 'None'

    test_dataloader, y_true = FBCG(data_path=config.DATA_PATH, zip_path=config.ZIP_PATH,
                                   augmentation=train_aug, unzip=unzip,
                                   batch_size=batch_size, test_mode=True,
                                   sn=params.sn, temp_dir=temp_dir)

    model = IDC_Grading_Model(feature_extractor).to(device)

    checkpoint_dir = _build_checkpoint_dir(config.SAVE_DIR, params.feature_extractor,
                                           params.train_aug, params.sn, params.temp)

    # Load weights from a specific saved iteration, or the best checkpoint.
    if params.save_iter != -1:
        modelfile = get_assigned_file(checkpoint_dir, params.save_iter)
    else:
        modelfile = get_best_file(checkpoint_dir)
    if modelfile is None:
        # Evaluating randomly-initialised weights would silently log a bogus
        # accuracy to results.txt, so fail loudly instead.
        raise FileNotFoundError(f'no checkpoint found in {checkpoint_dir}')
    # map_location lets GPU-saved checkpoints load on a CPU-only machine.
    checkpoint = torch.load(modelfile, map_location=device)
    model.load_state_dict(checkpoint['state'])

    acc = test(y_true, test_dataloader, model, device)
    print(f"Model: {feature_extractor}, Test acc: {acc:>0.2f}%")

    aug_str = '-aug' if params.train_aug else ''
    exp_setting = f'{feature_extractor} -{params.sn} -{params.temp} {aug_str}'
    acc_str = f'Test acc: {acc:>0.2f}%'
    with open(os.path.join(result_dir, 'results.txt'), 'a', encoding='utf-8') as f:
        f.write(f'Setting: {exp_setting}, Acc: {acc_str} \n')


if __name__ == '__main__':
    main()