-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy pathmodel.py
More file actions
127 lines (97 loc) · 3.42 KB
/
model.py
File metadata and controls
127 lines (97 loc) · 3.42 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import torch
import torch.utils.checkpoint as cp
import numpy as np
from ddr import TDV
class Dataterm(torch.nn.Module):
    """
    Abstract base class for dataterm functions D(x, ...).

    Subclasses implement the pointwise energy, its proximal operator and
    its gradient; this base class only fixes the interface.
    """
    def __init__(self, config):
        # `config` is accepted for interface uniformity with subclasses;
        # the base class itself does not use it.
        super(Dataterm, self).__init__()
    def forward(self, x, *args):
        """Evaluate the dataterm at x. Must be overridden."""
        raise NotImplementedError
    def energy(self, x, *args):
        """Pointwise energy of the dataterm. Must be overridden.

        Note: the signature now accepts the argument(s) every subclass
        implementation already takes (e.g. energy(x, z)); previously it
        inconsistently accepted none, unlike forward/prox/grad.
        """
        raise NotImplementedError
    def prox(self, x, *args):
        """Proximal operator of the dataterm at x. Must be overridden."""
        raise NotImplementedError
    def grad(self, x, *args):
        """Gradient of the dataterm w.r.t. x. Must be overridden."""
        raise NotImplementedError
class L2DenoiseDataterm(Dataterm):
    """Squared-L2 denoising dataterm D(x) = 0.5 * (x - z)^2 (pointwise)."""

    def __init__(self, config):
        super().__init__(config)

    def energy(self, x, z):
        """Pointwise energy 0.5 * (x - z)^2."""
        residual = x - z
        return 0.5 * residual ** 2

    def prox(self, x, z, tau):
        """Closed-form proximal map of tau * D evaluated at x."""
        return (x + tau * z) / (1 + tau)

    def grad(self, x, z):
        """Gradient of D with respect to x."""
        return x - z
class VNet(torch.nn.Module):
    """
    Variational Network.

    Unrolls S (proximal-)gradient steps on the energy
    R(x) + lmbda * D(x, z), with total time T split into steps tau = T / S.
    """
    def __init__(self, config, efficient=False):
        """
        Args:
            config: dict with keys 'S', 'T_mode', 'T', 'lambda_mode',
                'lambda', 'R' and 'D' describing the scheme, the
                regularizer and the dataterm.
            efficient: if True, recompute the regularizer gradient via
                gradient checkpointing during backprop to save memory.
        """
        super().__init__()
        self.efficient = efficient
        self.S = config['S']  # number of unrolled steps
        # setup the stopping time T: a fixed buffer or a learned Parameter
        if config['T_mode'] == 'fixed':
            self.register_buffer('T', torch.tensor(config['T']['init']))
        elif config['T_mode'] == 'learned':
            self.T = torch.nn.Parameter(torch.Tensor(1))
            self.reset_scalar(self.T, **config["T"])
            # per-parameter Lipschitz init, presumably read by the
            # training optimizer -- TODO confirm against the trainer
            self.T.L_init = 1e+3
        else:
            raise RuntimeError('T_mode unknown!')
        # setup the dataterm weight lambda (same fixed/learned scheme as T)
        if config['lambda_mode'] == 'fixed':
            self.register_buffer('lmbda', torch.tensor(config['lambda']['init']))
        elif config['lambda_mode'] == 'learned':
            self.lmbda = torch.nn.Parameter(torch.Tensor(1))
            self.reset_scalar(self.lmbda, **config["lambda"])
            self.lmbda.L_init = 1e+3
        else:
            raise RuntimeError('lambda_mode unknown!')
        # setup the regularization
        R_types = {
            'tdv': TDV,
        }
        self.R = R_types[config['R']['type']](config['R']['config'])
        # setup the dataterm
        self.use_prox = config['D']['config']['use_prox']
        D_types = {
            'denoise': L2DenoiseDataterm,
        }
        self.D = D_types[config['D']['type']](config['D']['config'])
    def reset_scalar(self, scalar, init=1., min=0, max=1000):
        """Initialize a scalar Parameter and attach a box-constraint projection.

        NOTE: `min`/`max` shadow the builtins, but the names must match the
        keys of the config dict that is **-unpacked into this call, so they
        cannot be renamed without breaking existing configs.
        """
        scalar.data = torch.tensor(init, dtype=scalar.dtype)
        # add a positivity constraint: in-place clamp into [min, max],
        # meant to be called after each optimizer step
        scalar.proj = lambda: scalar.data.clamp_(min, max)
    def forward(self, x, z, get_grad_R=False):
        """
        Run S unrolled steps starting from x.

        Args:
            x: initial estimate (tensor).
            z: observed data handed to the dataterm.
            get_grad_R: if True, also return the regularizer gradient of
                every step.

        Returns:
            x_all of shape (S+1, *x.shape) holding all iterates
            (x_all[0] is the input), and optionally grad_R_all of shape
            (S, *x.shape).
        """
        x_all = x.new_empty((self.S+1, *x.shape))
        x_all[0] = x
        if get_grad_R:
            grad_R_all = x.new_empty((self.S, *x.shape))
        # define the step size
        tau = self.T / self.S
        for s in range(1, self.S+1):
            # regularizer gradient; optionally checkpointed to trade
            # recomputation for activation memory during backprop
            if self.efficient and x.requires_grad:
                grad_R = cp.checkpoint(self.R.grad, x)
            else:
                grad_R = self.R.grad(x)
            if self.use_prox:
                # proximal-gradient step: explicit on R, prox on D
                x = self.D.prox(x - tau * grad_R, z, self.lmbda / self.S)
            else:
                # fully explicit gradient step on both R and D
                x = x - tau * grad_R - self.lmbda/self.S * self.D.grad(x, z)
            if get_grad_R:
                grad_R_all[s-1] = grad_R
            x_all[s] = x
        if get_grad_R:
            return x_all, grad_R_all
        else:
            return x_all
    def set_end(self, s):
        """Set the number of unrolled steps to a positive integer s.

        Raises ValueError instead of using `assert`, which is stripped
        when Python runs with -O.
        """
        if s <= 0:
            raise ValueError('number of steps must be positive, got %r' % (s,))
        self.S = s
    def extra_repr(self):
        """Show the step count in the module's repr."""
        s = "S={S}"
        return s.format(**self.__dict__)