-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel.py
More file actions
175 lines (143 loc) · 5.72 KB
/
model.py
File metadata and controls
175 lines (143 loc) · 5.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
'''
@file: model.py
This file contains all the models used in the experiments.
@author: Rukmangadh Sai Myana
@mail: rukman.sai@gmail.com
'''
import torch
import torch.nn as nn
# Lookup table from the 'layer_type' string used in a network definition to
# the torch.nn class that is instantiated for it. The 'flatten' marker is not
# listed here; the model classes handle it themselves.
_LAYER_MAP = dict(
    conv2d=nn.Conv2d,
    relu=nn.ReLU,
    softmax=nn.Softmax,
    linear=nn.Linear,
    batch_norm1d=nn.BatchNorm1d,
    batch_norm2d=nn.BatchNorm2d,
    dropout=nn.Dropout,
    dropout2d=nn.Dropout2d,
)
class BasicClassificationCNN(nn.Module):
    '''
    This class represents a basic CNN model that is used for the
    classification of the MNIST Dataset.

    The model is assembled from a list of layer definitions. Layers listed
    before the 'flatten' marker form the convolutional stage; layers listed
    after it form the fully connected stage.
    '''

    def __init__(self, network_definition):
        '''
        Initialize the model.

        @param network_definition: An iterable of dicts, one per layer. Each
        dict carries a 'layer_type' (a key of _LAYER_MAP, or 'flatten'), a
        'layer_name', and the keyword arguments passed to the constructor of
        that layer. The dicts are left unmodified.
        @raises KeyError: If a definition lacks 'layer_type' or 'layer_name',
        or names a layer type that is not present in _LAYER_MAP.
        '''
        # initialize the parent class
        super().__init__()

        conv_layers, fc_layers = self._build_layers(network_definition)

        # ModuleDict accepts an iterable of (string, Module) pairs and keeps
        # their insertion order, which is the execution order in forward().
        self.conv_network = nn.ModuleDict(conv_layers)
        self.fc_network = nn.ModuleDict(fc_layers)

    @staticmethod
    def _build_layers(layer_definitions):
        '''
        Instantiate the layers and split them at the 'flatten' marker.

        @param layer_definitions: An iterable of layer definition dicts.
        @returns: A (conv_layers, fc_layers) tuple of lists holding
        (layer_name, module) pairs.
        '''
        conv_layers = []
        fc_layers = []
        current = conv_layers
        for layer_definition in layer_definitions:
            # Work on a shallow copy: popping from the caller's dict would
            # strip 'layer_type'/'layer_name' from it and make the same
            # definition unusable for building a second model.
            layer_kwargs = dict(layer_definition)
            layer_type = layer_kwargs.pop('layer_type')
            layer_name = layer_kwargs.pop('layer_name')

            if layer_type == 'flatten':
                # everything after the marker belongs to the fc stage
                current = fc_layers
                continue

            current.append(
                (layer_name, _LAYER_MAP[layer_type](**layer_kwargs)))
        return conv_layers, fc_layers

    def forward(self, x):
        '''
        Define the forward pass for the model.

        The input goes through the convolutional stage, is flattened from
        dimension 1 onwards, and then goes through the fully connected stage.

        @param x: input
        @returns: output
        '''
        for layer in self.conv_network.values():
            x = layer(x)

        x = torch.flatten(x, 1)

        for layer in self.fc_network.values():
            x = layer(x)
        return x
class ModularClassificationCNN(nn.Module):
    '''
    This class represents the block modular neural network architecture we are
    interested in developing.

    Two networks are held side by side: a pre-trained neural network (PNN)
    and a block neural network (BNN). In the forward pass the PNN activations
    are concatenated into the BNN input ahead of selected layers, and only
    the BNN output is returned.

    The forward pass pairs layers by name: for a PNN layer called NAME, the
    BNN layer called 'bnn_' + NAME is looked up.
    '''

    def __init__(self, network_definition):
        '''
        Initialize the model.

        @param network_definition: A mapping with the keys 'PNN' and 'BNN'.
        Each value is an iterable of dicts, one per layer, carrying a
        'layer_type' (a key of _LAYER_MAP, or 'flatten'), a 'layer_name', and
        the keyword arguments passed to the constructor of that layer. The
        dicts are left unmodified.
        @raises KeyError: If 'PNN' or 'BNN' is missing, if a layer definition
        lacks 'layer_type' or 'layer_name', or if a layer type is not present
        in _LAYER_MAP.
        '''
        # initialize the parent class
        super().__init__()

        # define the pre-trained neural network
        conv_layers, fc_layers = self._build_layers(network_definition['PNN'])
        self.conv_network = nn.ModuleDict(conv_layers)
        self.fc_network = nn.ModuleDict(fc_layers)

        # define the block neural network
        conv_layers, fc_layers = self._build_layers(network_definition['BNN'])
        self.bnn_conv_network = nn.ModuleDict(conv_layers)
        self.bnn_fc_network = nn.ModuleDict(fc_layers)

    @staticmethod
    def _build_layers(layer_definitions):
        '''
        Instantiate the layers and split them at the 'flatten' marker.

        @param layer_definitions: An iterable of layer definition dicts.
        @returns: A (conv_layers, fc_layers) tuple of lists holding
        (layer_name, module) pairs, ready to be passed to nn.ModuleDict.
        '''
        conv_layers = []
        fc_layers = []
        current = conv_layers
        for layer_definition in layer_definitions:
            # Work on a shallow copy: popping from the caller's dict would
            # strip 'layer_type'/'layer_name' from it and make the same
            # definition unusable for building a second model.
            layer_kwargs = dict(layer_definition)
            layer_type = layer_kwargs.pop('layer_type')
            layer_name = layer_kwargs.pop('layer_name')

            if layer_type == 'flatten':
                # everything after the marker belongs to the fc stage
                current = fc_layers
                continue

            current.append(
                (layer_name, _LAYER_MAP[layer_type](**layer_kwargs)))
        return conv_layers, fc_layers

    def forward(self, x):
        '''
        Define the forward pass for the model.

        Both networks start from the same input and are advanced layer by
        layer in lockstep, driven by the PNN layer names.

        @param x: input
        @returns: outputs of the block neural network; the output of the
        pre-trained network is not returned.
        @raises KeyError: If the BNN has no layer named 'bnn_' + NAME for
        some PNN layer NAME.
        '''
        bnn_x = x
        first_layer = True  # no knowledge transfer ahead of the first layer

        # propagate through convolutional layers
        for key, layer in self.conv_network.items():
            bnn_key = 'bnn_' + key
            if (not first_layer) and ('conv' in bnn_key):
                # knowledge transfer: feed the PNN activation into the BNN
                # by concatenating along dimension 1
                bnn_x = torch.cat((x, bnn_x), dim=1)
            x = layer(x)
            bnn_x = self.bnn_conv_network[bnn_key](bnn_x)
            first_layer = False

        x = torch.flatten(x, 1)
        bnn_x = torch.flatten(bnn_x, 1)

        # propagate through fully connected layers
        for key, layer in self.fc_network.items():
            bnn_key = 'bnn_' + key
            if 'linear' in bnn_key:
                # knowledge transfer ahead of every linear layer
                bnn_x = torch.cat((x, bnn_x), dim=1)
            x = layer(x)
            bnn_x = self.bnn_fc_network[bnn_key](bnn_x)

        return bnn_x