-
Notifications
You must be signed in to change notification settings - Fork 146
Expand file tree
/
Copy pathinteractive_predictor.py
More file actions
139 lines (124 loc) · 5.36 KB
/
interactive_predictor.py
File metadata and controls
139 lines (124 loc) · 5.36 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import torch
import numpy as np
from torchvision import transforms
from utils.visualizer import Visualizer
from typing import Tuple
from PIL import Image
from detectron2.data import MetadataCatalog
# COCO-2017 panoptic metadata (class names / colors); passed to detectron2's
# Visualizer so drawn masks use a consistent color scheme.
metadata = MetadataCatalog.get('coco_2017_train_panoptic')
class SemanticSAMPredictor:
    """Interactive click-based predictor wrapping a Semantic-SAM model.

    Given a single normalized click point, produces up to 6 candidate masks
    at different granularities, filters low-confidence / duplicate masks,
    cleans them SAM-style (fill holes, drop islands), and renders each mask
    over the original image with the click marked in red.
    """

    def __init__(self, model, thresh=0.5, text_size=640, hole_scale=100, island_scale=100):
        """
        model: wrapped Semantic-SAM model exposing ``model.model.evaluate_demo``
        thresh: IoU threshold used to filter low-confidence masks
        text_size: resize target for the input image short edge
        hole_scale: max area (px) of holes to fill, as in SAM
        island_scale: max area (px) of disconnected regions to remove, as in SAM
        """
        self.model = model
        self.thresh = thresh
        # BUG FIX: original code assigned ``self.text_size = hole_scale``,
        # silently ignoring the ``text_size`` argument.
        self.text_size = text_size
        self.hole_scale = hole_scale
        self.island_scale = island_scale
        # Last click position in pixel coordinates (set by predict()); used by
        # process_multi_mask() to draw the click marker.
        self.point = None

    def predict(self, image_ori, image, point=None):
        """Run the model on one click and return raw (masks, ious).

        image_ori: original image as an HxWxC array (used only for its size)
        image: preprocessed image tensor expected by the model
        point: optional [[x, y, w, h]]-style normalized point prompt; defaults
            to the image center with a tiny box.

        Returns the (masks, ious) pair produced by ``evaluate_demo``.
        """
        height, width = image_ori.shape[0], image_ori.shape[1]
        data = {"image": image, "height": height, "width": width}
        if point is None:
            # Default prompt: image center with a tiny (0.6%) box extent.
            point = torch.tensor([[0.5, 0.5, 0.006, 0.006]]).cuda()
        else:
            point = torch.tensor(point).cuda()
        # Clone so we never mutate the caller's tensor.
        point = point.clone()
        # Append a small fixed box extent to the (x, y) prompt.
        point = torch.cat([point, point.new_tensor([[0.005, 0.005]])], dim=-1)
        # Remember the click in pixel coordinates for later visualization.
        self.point = point[:, :2].clone() * (torch.tensor([width, height]).to(point))
        data['targets'] = [dict()]
        data['targets'][0]['points'] = point
        data['targets'][0]['pb'] = point.new_tensor([0.])
        masks, ious = self.model.model.evaluate_demo([data])
        return masks, ious

    def process_multi_mask(self, masks, ious, image_ori):
        """Filter, clean, and render candidate masks.

        masks: raw mask logits from predict(), indexable as [candidate, H, W]
        ious: predicted IoU scores, shape [1, 1, num_candidates]
        image_ori: original HxWxC image to draw the masks on

        Returns (images, images_by_area): the rendered PIL images in
        descending-IoU order, and the same images re-sorted by ascending
        mask area.
        """
        ious = ious[0, 0]
        # Visit candidates from most to least confident.
        order = torch.argsort(ious, descending=True)
        rendered = []
        kept_masks = []
        kept_ious = []
        areas = []
        candidates = masks[order]
        for idx, (mask_logits, iou) in enumerate(zip(candidates, ious[order])):
            iou = round(float(iou), 2)
            label = f'{iou}'
            mask = (mask_logits > 0.0).cpu().numpy()
            area = mask.sum()
            # Skip low-confidence masks and near-duplicates (>95% IoU with a
            # mask we already kept) -- unless this is the last candidate and
            # nothing has been kept yet, in which case keep it as a fallback.
            skip = iou < self.thresh
            if not skip:
                for kept in kept_masks:
                    overlap = np.logical_and(mask, kept).sum() / np.logical_or(mask, kept).sum()
                    if overlap > 0.95:
                        skip = True
                        break
            if idx == len(candidates) - 1 and not kept_masks:
                skip = False
            if skip:
                continue
            kept_ious.append(iou)
            kept_masks.append(mask)
            areas.append(area)
            # SAM-style post-processing: fill small holes, drop small islands.
            mask, _ = self.remove_small_regions(mask, int(self.hole_scale), mode="holes")
            mask, _ = self.remove_small_regions(mask, int(self.island_scale), mode="islands")
            mask = mask.astype(float)
            visual = Visualizer(image_ori, metadata=metadata)
            demo = visual.draw_binary_mask(mask, color=[0., 0., 1.0], text=label)
            res = demo.get_image()
            # Paint a 6x6 red square at the click location (clamped to bounds).
            x0 = max(0, int(self.point[0, 0]) - 3)
            x1 = min(image_ori.shape[1], int(self.point[0, 0]) + 3)
            y0 = max(0, int(self.point[0, 1]) - 3)
            y1 = min(image_ori.shape[0], int(self.point[0, 1]) + 3)
            res[y0:y1, x0:x1, 0] = 255
            res[y0:y1, x0:x1, 1] = 0
            res[y0:y1, x0:x1, 2] = 0
            rendered.append(Image.fromarray(res))
        # Second ordering: smallest mask first.
        by_area = [int(i) for i in torch.argsort(torch.tensor(areas), descending=False)]
        torch.cuda.empty_cache()
        return rendered, [rendered[i] for i in by_area]

    def predict_masks(self, image_ori, image, point=None):
        """Convenience wrapper: predict then render; see process_multi_mask()."""
        masks, ious = self.predict(image_ori, image, point)
        return self.process_multi_mask(masks, ious, image_ori)

    @staticmethod
    def remove_small_regions(
        mask: np.ndarray, area_thresh: float, mode: str
    ) -> Tuple[np.ndarray, bool]:
        """
        Removes small disconnected regions and holes in a mask. Returns the
        mask and an indicator of if the mask has been modified.
        """
        import cv2  # type: ignore

        assert mode in ["holes", "islands"]
        correct_holes = mode == "holes"
        # For hole removal, work on the inverted mask so holes become islands.
        working_mask = (correct_holes ^ mask).astype(np.uint8)
        n_labels, regions, stats, _ = cv2.connectedComponentsWithStats(working_mask, 8)
        sizes = stats[:, -1][1:]  # Row 0 is background label
        small_regions = [i + 1 for i, s in enumerate(sizes) if s < area_thresh]
        if len(small_regions) == 0:
            return mask, False
        fill_labels = [0] + small_regions
        if not correct_holes:
            fill_labels = [i for i in range(n_labels) if i not in fill_labels]
            # If every region is below threshold, keep largest
            if len(fill_labels) == 0:
                fill_labels = [int(np.argmax(sizes)) + 1]
        mask = np.isin(regions, fill_labels)
        return mask, True