-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_synapse_lomix.py
More file actions
143 lines (126 loc) · 7.37 KB
/
train_synapse_lomix.py
File metadata and controls
143 lines (126 loc) · 7.37 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import argparse
import logging
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
from lib.networks import PVT_CASCADE, EMCADNet
from trainer import trainer_synapse
# ---------------------------------------------------------------------------
# Command-line interface: dataset locations, EMCAD network hyper-parameters
# and the training schedule for Synapse multi-organ segmentation.
# Parsed once at import time into the module-level `args`.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--root_path', type=str,
                    default='./data/synapse/train_npz_new', help='root dir for data')
parser.add_argument('--volume_path', type=str,
                    default='./data/synapse/test_vol_h5_new', help='root dir for validation volume data')
parser.add_argument('--dataset', type=str,
                    default='Synapse', help='experiment_name')
parser.add_argument('--list_dir', type=str,
                    default='./lists/lists_Synapse', help='list dir')
parser.add_argument('--num_classes', type=int,
                    default=9, help='output channel of network')
# network related parameters
parser.add_argument('--encoder', type=str,
                    default='pvt_v2_b2', help='Name of encoder: pvt_v2_b2, pvt_v2_b0, resnet18, resnet34 ...')
parser.add_argument('--expansion_factor', type=int,
                    default=2, help='expansion factor in MSCB block')
parser.add_argument('--kernel_sizes', type=int, nargs='+',
                    default=[1, 3, 5], help='multi-scale kernel sizes in MSDC block')
parser.add_argument('--lgag_ks', type=int,
                    default=3, help='Kernel size in LGAG')
parser.add_argument('--activation_mscb', type=str,
                    default='relu6', help='activation used in MSCB: relu6 or relu')
parser.add_argument('--no_dw_parallel', action='store_true',
                    default=False, help='use this flag to disable depth-wise parallel convolutions')
parser.add_argument('--concatenation', action='store_true',
                    default=False, help='use this flag to concatenate feature maps in MSDC block')
parser.add_argument('--no_pretrain', action='store_true',
                    default=False, help='use this flag to turn off loading pretrained encoder weights')
parser.add_argument('--supervision', type=str,
                    default='lomix', help='loss supervision: lomix, mutation, deep_supervision or last_layer')
# Fixed: help text previously said "maximum epoch number" (copied from
# --max_epochs); this flag controls the iteration budget.
parser.add_argument('--max_iterations', type=int,
                    default=50000, help='maximum iteration number to train')
parser.add_argument('--max_epochs', type=int,
                    default=300, help='maximum epoch number to train')
parser.add_argument('--batch_size', type=int,
                    default=6, help='batch_size per gpu')
parser.add_argument('--base_lr', type=float, default=0.0001,
                    help='segmentation network learning rate')
parser.add_argument('--img_size', type=int,
                    default=224, help='input patch size of network input')
parser.add_argument('--n_gpu', type=int, default=1, help='total gpu')
parser.add_argument('--deterministic', type=int, default=1,
                    help='whether use deterministic training')
parser.add_argument('--seed', type=int,
                    default=2222, help='random seed')
args = parser.parse_args()
if __name__ == "__main__":
    # --- Reproducibility --------------------------------------------------
    # cudnn.benchmark trades determinism for speed; disable it and seed all
    # RNGs (python, numpy, torch CPU/GPU) when deterministic training is on.
    if not args.deterministic:
        cudnn.benchmark = True
        cudnn.deterministic = False
    else:
        cudnn.benchmark = False
        cudnn.deterministic = True
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed(args.seed)

    # --- Dataset configuration --------------------------------------------
    dataset_name = args.dataset
    dataset_config = {
        'Synapse': {
            'root_path': args.root_path,
            'volume_path': args.volume_path,
            'list_dir': args.list_dir,
            'num_classes': args.num_classes,
            'z_spacing': 1,
        },
    }
    args.num_classes = dataset_config[dataset_name]['num_classes']
    args.root_path = dataset_config[dataset_name]['root_path']
    args.volume_path = dataset_config[dataset_name]['volume_path']
    args.z_spacing = dataset_config[dataset_name]['z_spacing']
    args.list_dir = dataset_config[dataset_name]['list_dir']

    print(args.no_pretrain)

    # Decoder wiring choices; both are echoed into the experiment name below.
    aggregation = 'concat' if args.concatenation else 'add'
    dw_mode = 'series' if args.no_dw_parallel else 'parallel'
    print(aggregation)

    # --- Loss supervision setup -------------------------------------------
    # 'lomix' mixes head outputs with several merge operations and learnable
    # weights; 'mutation' uses additive mixing only; anything else disables
    # output mixing entirely.
    use_learnable_weights = False
    if args.supervision == 'lomix':
        operations = ['add', 'mul', 'wf', 'concat']
        use_learnable_weights = True
    elif args.supervision == 'mutation':
        operations = ['add']
    else:
        operations = []
    learnable = 'learnable_' if use_learnable_weights else ''

    # --- Experiment / snapshot naming -------------------------------------
    run = 1
    # Shared configuration tag used by both the experiment name and the
    # snapshot sub-directory (previously duplicated inline, character for
    # character; the produced strings are unchanged).
    config_tag = (args.encoder + '_EMCADNet_kernel_sizes_' + str(args.kernel_sizes)
                  + '_dw_' + dw_mode + '_' + aggregation
                  + '_lgag_ks_' + str(args.lgag_ks)
                  + '_act_mscb_' + args.activation_mscb
                  + '_loss_' + learnable + 'relative_softp_' + args.supervision
                  + '_' + str(operations)
                  + '_output_last_layer_Run' + str(run))
    args.exp = config_tag + '_' + dataset_name + str(args.img_size)
    snapshot_path = "model_pth/{}/{}".format(args.exp, config_tag)
    # Strip list formatting (e.g. "[1, 3, 5]" -> "1_3_5") so the path is
    # filesystem-friendly.
    snapshot_path = snapshot_path.replace('[', '').replace(']', '').replace(', ', '_')

    # Only non-default settings are appended, keeping default runs short.
    snapshot_path = snapshot_path + '_pretrain' if not args.no_pretrain else snapshot_path
    snapshot_path = snapshot_path + '_' + str(args.max_iterations)[0:2] + 'k' if args.max_iterations != 50000 else snapshot_path
    snapshot_path = snapshot_path + '_epo' + str(args.max_epochs) if args.max_epochs != 300 else snapshot_path
    snapshot_path = snapshot_path + '_bs' + str(args.batch_size)
    snapshot_path = snapshot_path + '_lr' + str(args.base_lr) if args.base_lr != 0.0001 else snapshot_path
    snapshot_path = snapshot_path + '_' + str(args.img_size)
    # NOTE(review): the seed suffix is skipped only for seed 1234, but the
    # CLI default seed is 2222 — confirm whether this mismatch is intentional
    # (default runs currently get an '_s2222' suffix).
    snapshot_path = snapshot_path + '_s' + str(args.seed) if args.seed != 1234 else snapshot_path

    # exist_ok avoids the check-then-create race of the previous
    # os.path.exists() + os.makedirs() pair.
    os.makedirs(snapshot_path, exist_ok=True)
    print(snapshot_path)

    # --- Model -------------------------------------------------------------
    # PVT_CASCADE (imported above) is an alternative backbone kept for
    # experimentation; EMCADNet is the one trained here.
    model = EMCADNet(num_classes=args.num_classes,
                     kernel_sizes=args.kernel_sizes,
                     expansion_factor=args.expansion_factor,
                     dw_parallel=not args.no_dw_parallel,
                     add=not args.concatenation,
                     lgag_ks=args.lgag_ks,
                     activation=args.activation_mscb,
                     encoder=args.encoder,
                     pretrain=not args.no_pretrain)
    model.cuda()
    print('Model successfully created.')

    # --- Train -------------------------------------------------------------
    trainer = {'Synapse': trainer_synapse}
    trainer[dataset_name](args, model, snapshot_path,
                          supervision=args.supervision,
                          operations=operations,
                          use_learnable_weights=use_learnable_weights)