-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsummary.py
More file actions
69 lines (60 loc) · 3.3 KB
/
summary.py
File metadata and controls
69 lines (60 loc) · 3.3 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# Fix: the module-author dunder was misspelled `_author__` (missing a leading
# underscore), so the conventional `__author__` attribute was never defined.
__author__ = 'MSteger'
import torch
import numpy as np
class summary(object):
    """Per-layer summary of a PyTorch model built from a dummy forward pass.

    Runs a zero-filled dummy input through the model's top-level modules
    (descending one level into ``torch.nn.Sequential`` containers) and records,
    for each layer: owning module name, layer class name, input shape, output
    shape, trainable parameter count, and frozen parameter count.

    Parameters
    ----------
    model : torch.nn.Module
        Model to inspect; it is moved onto ``device``.
    device : torch.device
        Device for the model and the dummy input (default: CPU).
    input_size : tuple
        Full size of the dummy input, including the batch dimension.
    verbose : bool
        If True, print the summary table immediately after construction.
    """

    def __init__(self, model, device = torch.device('cpu'), input_size=(1,1,256,256), verbose = True):
        self.model = model.to(device)
        self.input_size = input_size
        self.device = device
        self.iterate()
        if verbose: self.printer()

    def compute_output(self, input, layer):
        """Forward `input` through `layer`, flattening it first for Linear layers.

        Falls back to global average pooling over the last two dimensions when
        the activation cannot be reshaped to the Linear layer's `in_features`.
        """
        if isinstance(layer, torch.nn.Linear):
            try:
                # Reshape the conv activation to (1, in_features).
                # NOTE: `reshape` replaces the removed/deprecated `Tensor.resize`;
                # like the original, a size mismatch raises and triggers the hack.
                input = input.reshape(1, layer.in_features)
            except Exception as e:
                # py3 fix: `print` statement -> function (message unchanged)
                print('Failure! {} >> using hack'.format(e))
                input = input.mean(-1).mean(-1)  # global average pool H and W
        return layer(input)

    def compute_no_params(self, layer):
        """Return ``[trainable, frozen]`` parameter counts for `layer`."""
        params_to_optim, params_frozen = [], []
        for p in layer.parameters():
            if p.requires_grad:
                params_to_optim.append(tuple(p.size()))
            else:
                params_frozen.append(tuple(p.size()))
        # np.sum([]) == 0.0, so layers with no parameters correctly count as 0;
        # int(...) replaces the original's .astype(int) with the same values.
        return [int(np.sum([np.prod(p) for p in params_to_optim])),
                int(np.sum([np.prod(p) for p in params_frozen]))]

    def iterate(self):
        """Run the dummy input through the model, collecting one row per layer.

        Returns the list of rows and also stores it on ``self.summary``.
        """
        summary = []
        with torch.no_grad():
            # Deterministic zero input. The original used an *uninitialized*
            # torch.FloatTensor wrapped in the deprecated Variable API.
            input = torch.zeros(self.input_size, device=self.device)
            # py3 fix: dict.iteritems() no longer exists -> items()
            for k, v in self.model._modules.items():
                if isinstance(v, torch.nn.Sequential):
                    # Descend one level into Sequential containers.
                    for layer in v:
                        output = self.compute_output(input, layer)
                        summary.append([k, type(layer).__name__, tuple(input.shape)[1:], tuple(output.shape)[1:]] + self.compute_no_params(layer))
                        input = output
                else:
                    layer = v
                    output = self.compute_output(input, layer)
                    summary.append([k, type(layer).__name__, tuple(input.shape)[1:], tuple(output.shape)[1:]] + self.compute_no_params(layer))
                    input = output
        self.summary = summary
        return summary

    def printer(self, summary = None):
        """Pretty-print the summary table plus total/trainable/frozen counts.

        Uses ``self.summary`` when no explicit `summary` is given. Returns self.
        """
        if summary is None: summary = self.summary
        total_params, trainable_params = 0, 0
        # py3 fix throughout: print statements -> print() calls (text unchanged)
        print('Model Summary')
        print('---------------------------------------------------------------------------------------------------------------------------------')
        print('{:>2} {:>20} {:>20} {:>20} {:>20} {:>20} {:>20}'.format('Id', 'Name', 'Type', 'Input', 'Output', 'Params', 'Params(Frozen)'))
        print('---------------------------------------------------------------------------------------------------------------------------------')
        for idx, layer in enumerate(summary):
            print('{:>2} {:>20} {:>20} {:>20} {:>20} {:>20} {:>20}'.format(*[idx+1]+layer))
            total_params += layer[-2] + layer[-1]
            trainable_params += layer[-2]
        print('=================================================================================================================================')
        print('Total params: {0:,}'.format(total_params))
        print('Trainable params: {0:,}'.format(trainable_params))
        print('Non-trainable params: {0:,}'.format(total_params - trainable_params))
        print('---------------------------------------------------------------------------------------------------------------------------------')
        return self