-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy pathdenoise.py
More file actions
68 lines (52 loc) · 1.71 KB
/
denoise.py
File metadata and controls
68 lines (52 loc) · 1.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import os
import imageio
import numpy as np
import torch
import model
import matplotlib.pyplot as plt
# select color or gray-scale (also selects the matching checkpoint file below)
# color = 'gray'
color = 'color'
# define the noise level on the 0-255 intensity scale (divided by 255 when
# the noise is added, since images are scaled to [0, 1])
sigma = 25
# load a test image and scale it to [0, 1] (the /255 assumes 8-bit data)
y = imageio.imread(os.path.join('data', 'water-castle.png')).astype(np.float32) / 255
if y.ndim == 3 and y.shape[2] == 4:
    # drop an alpha channel if the PNG has one, so it is neither averaged into
    # the gray value nor passed to the network as an extra channel
    y = y[..., :3]
if color == 'gray':
    # average over the channel axis, keeping it as a size-1 trailing axis so
    # the later (2, 0, 1) transpose still yields a channel dimension
    y = np.mean(y, 2, keepdims=True)
# add white Gaussian noise with standard deviation sigma/255
z = y + sigma / 255. * np.random.randn(*y.shape).astype(np.float32)
# load the model state dict; map it to CPU so loading does not depend on the
# device the checkpoint was saved from (the network is moved to CUDA below)
checkpoint = torch.load(
    os.path.join('checkpoints', f'tdv3-3-25-f32-{color}.pth'),
    map_location='cpu',
)
# noise level the network is applied at: apply_vn rescales inputs by
# sigma_ref / sigma (presumably the level the checkpoint was trained for)
sigma_ref = 25
# get the variational network with the TDV regularizer
vn = model.VNet(checkpoint['config'], efficient=False)
vn.load_state_dict(checkpoint['model'])
vn.cuda()
# switch to inference mode for any layers whose behavior depends on
# train/eval mode (standard PyTorch practice before evaluation)
vn.eval()
# define the evaluation metric
def psnr(x, y, peak=1.0):
    """Return the peak signal-to-noise ratio between two images, in dB.

    Args:
        x: Estimated image (array-like).
        y: Reference image, broadcast-compatible with ``x``.
        peak: Largest possible pixel value. The default 1.0 matches images
            scaled to [0, 1].

    Returns:
        PSNR in decibels as a Python float, or ``inf`` when the images are
        identical (zero mean-squared error).
    """
    # compute in float64 so float32 inputs do not lose precision in the MSE
    diff = np.asarray(x, dtype=np.float64) - np.asarray(y, dtype=np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        # identical images: return inf explicitly instead of dividing by zero
        return float('inf')
    return float(20 * np.log10(peak / np.sqrt(mse)))
# define the application of the VN
def apply_vn(x_0, z, *, sigma_noise=None, sigma_reference=None, net=None):
    """Apply the variational network after rescaling to its reference noise level.

    Both inputs are multiplied by ``sigma_reference / sigma_noise`` so that the
    noise in ``z`` matches the reference level. The network is called as
    ``net(x_0 * scale, z * scale)``, and every element of the sequence it
    returns is divided by the same factor to restore the original intensity
    scale.

    Args:
        x_0: Initial estimate passed to the network.
        z: Noisy observation passed to the network.
        sigma_noise: Noise level of ``z``; defaults to the module-level ``sigma``.
        sigma_reference: Reference noise level; defaults to the module-level
            ``sigma_ref``.
        net: Callable network; defaults to the module-level ``vn``.

    Returns:
        A list with each network output rescaled back by ``1 / scale``.
    """
    if sigma_noise is None:
        sigma_noise = sigma
    if sigma_reference is None:
        sigma_reference = sigma_ref
    if net is None:
        net = vn
    # transform to the reference noise level
    scale = sigma_reference / sigma_noise
    outputs = net(x_0 * scale, z * scale)
    # convert back to the original scale
    return [out / scale for out in outputs]
# push the noisy image to the GPU as a (1, C, H, W) batch
z_th = torch.from_numpy(np.transpose(z, (2, 0, 1))[None]).cuda()
with torch.no_grad():
    x_th = apply_vn(z_th, z_th)
# take the last output of the returned list, first batch item, back to (H, W, C)
x_S = np.transpose(x_th[-1][0].cpu().numpy(), (1, 2, 0))
# show the result
fig, ax = plt.subplots(1, 3, sharex=True, sharey=True)
for axis, (img, title) in zip(ax, [(z, 'z'), (x_S, 'x_S'), (y, 'y')]):
    # imshow ignores vmin/vmax for RGB input and warns when clipping float
    # data outside [0, 1]; clip explicitly (display only, PSNR uses raw data)
    axis.imshow(np.clip(img, 0, 1).squeeze(), vmin=0, vmax=1, cmap='gray')
    axis.set_title(title)
ax[0].set_xlabel(f'PSNR={psnr(z,y):.2f}dB')
ax[1].set_xlabel(f'PSNR={psnr(x_S,y):.2f}dB')
plt.show()