-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsample.py
More file actions
80 lines (69 loc) · 2.72 KB
/
sample.py
File metadata and controls
80 lines (69 loc) · 2.72 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
"""Samples the test data, to assess segmentation quality."""
from typing import Dict
import matplotlib.pyplot as plt
import torch as th
from data_loader import Loader
from train import UNet3D, normalize
from util import softmax_focal_loss
if __name__ == "__main__":
    # Change the checkpoint name to match the desired model.
    checkpoint_name = "./weights/unet_softmaxfl_124.pkl"
    device = th.device("cuda") if th.cuda.is_available() else th.device("cpu")
    # Intensity statistics used for input normalization — presumably computed
    # on the training split; must match the values used during training.
    mean = th.Tensor([206.12558]).to(device)
    std = th.Tensor([164.74423]).to(device)
    input_shape = [128, 128, 21]
    val_keys = ["ProstateX-0004", "ProstateX-0007"]
    data_set = Loader(input_shape=input_shape, val_keys=val_keys)
    model = UNet3D()
    # map_location lets a checkpoint saved on GPU load on a CPU-only host,
    # matching the device fallback chosen above.
    model.load_state_dict(th.load(checkpoint_name, map_location=device))
    model = model.to(device)
    model.eval()
    val_data = data_set.get_val(True)
    input_val = normalize(val_data["images"].to(device), mean=mean, std=std)
    with th.no_grad():
        val_out = model(input_val)
    # Move the class dimension last so it lines up with the one-hot labels
    # below (assumes model output is (batch, classes, depth, H, W) — TODO confirm).
    val_out = val_out.permute((0, 3, 4, 2, 1))
    label_val = th.nn.functional.one_hot(
        val_data["annotation"].type(th.int64), num_classes=5
    ).to(device)
    val_loss = th.mean(
        softmax_focal_loss(val_out, label_val, th.ones(val_out.shape[-1]).to(device))
    )
    print(f"Validation loss: {val_loss:2.6f}")
    val_out = val_out.cpu()
def disp_result(
sample: int,
data: Dict[str, th.Tensor],
out: th.Tensor,
name: str,
slice: int = 11,
):
"""Plot the original image, network output and annotation."""
plt.title("scan")
plt.imshow(data["images"].squeeze(1)[sample, :, :, slice])
plt.savefig(f"test_scan_{name}.png")
plt.title("network")
plt.imshow(th.argmax(out[sample, :, :, slice], dim=-1), vmin=0, vmax=5)
plt.savefig(f"test_network_{name}.png")
plt.title("human expert")
plt.imshow(data["annotation"][sample, :, :, slice], vmin=0, vmax=5)
plt.savefig(f"test_expert_{name}.png")
disp_result(0, val_data, val_out, "val_0")
disp_result(1, val_data, val_out, "val_1")
test_data = data_set.get_test_set()
input_test = normalize(test_data["images"].to(device), mean=mean, std=std)
with th.no_grad():
test_out = model(input_test)
test_out = test_out.permute((0, 3, 4, 2, 1))
label_test = th.nn.functional.one_hot(
test_data["annotation"].type(th.int64), num_classes=5
).to(device)
test_loss = th.mean(
softmax_focal_loss(
test_out, label_test, th.ones((test_out.shape[-1])).to(device)
)
)
print(f"Test loss: {test_loss:2.6f}")
test_out = test_out.cpu()
for i in range(20):
disp_result(i, test_data, test_out, f"test_{str(i)}")