-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path model.py
More file actions
74 lines (57 loc) · 2.63 KB
/
model.py
File metadata and controls
74 lines (57 loc) · 2.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import torch
import torchvision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
class DETRModel(torch.nn.Module):
    """DETR (DEtection TRansformer) with its classification head resized.

    Loads the pretrained ResNet-50 DETR from torch hub and replaces the
    class-embedding layer so it emits `num_classes` logits per object query.
    """

    def __init__(self, num_classes, num_queries):
        super().__init__()
        self.num_classes = num_classes
        self.num_queries = num_queries
        # NOTE: this downloads the pretrained checkpoint on first use.
        self.model = torch.hub.load('facebookresearch/detr', 'detr_resnet50', pretrained=True)
        # Swap the pretrained classification head for one sized to our dataset.
        self.in_features = self.model.class_embed.in_features
        self.model.class_embed = torch.nn.Linear(
            in_features=self.in_features,
            out_features=self.num_classes,
        )
        self.model.num_queries = self.num_queries

    def forward(self, images):
        """Run the wrapped DETR model on a batch of images."""
        return self.model(images)
def box_cxcywh_to_xyxy(x):
    """Convert (N, 4) boxes from (cx, cy, w, h) to (x_min, y_min, x_max, y_max)."""
    cx, cy, w, h = x.unbind(1)
    half_w = 0.5 * w
    half_h = 0.5 * h
    corners = (cx - half_w, cy - half_h, cx + half_w, cy + half_h)
    return torch.stack(corners, dim=1)
def rescale_bboxes(out_bbox, size):
    """Scale normalized (cx, cy, w, h) boxes in [0, 1] to absolute xyxy pixels.

    Args:
        out_bbox: (N, 4) tensor of normalized cxcywh boxes.
        size: (width, height) of the target image.

    Returns:
        (N, 4) tensor of absolute (x_min, y_min, x_max, y_max) boxes, on the
        same device as `out_bbox`.
    """
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    # Bug fix: build the scale factor on the input's device instead of the
    # hard-coded "cuda", which crashed on CPU-only machines and raised a
    # device-mismatch error for CPU tensors.
    scale = torch.tensor([img_w, img_h, img_w, img_h],
                         dtype=torch.float32, device=out_bbox.device)
    return b * scale
def detect(im, model, device, transform=None):
    """Run DETR inference on a single image and return confident predictions.

    Args:
        im: input image; must expose `.size` as (width, height) (e.g. PIL).
        model: DETR-style model whose output dict has 'pred_logits' and
            'pred_boxes'.
        device: torch device to run inference on.
        transform: optional preprocessing transform applied to `im`; when
            None, `im` is assumed to already be a CHW tensor.

    Returns:
        (probas, bboxes_scaled): per-query class probabilities for the kept
        queries and their boxes rescaled to absolute image coordinates.
    """
    # mean-std normalize the input image (batch-size: 1)
    if transform is not None:  # idiom fix: `is not None` instead of `!= None`
        img = transform(im).unsqueeze(0)
    else:
        img = im.unsqueeze(0)
    # demo model only support by default images with aspect ratio between 0.5 and 2
    # if you want to use images with an aspect ratio outside this range
    # rescale your image so that the maximum size is at most 1333 for best results
    # assert img.shape[-2] <= 1600 and img.shape[-1] <= 1600, 'demo model only supports images up to 1600 pixels on each side'
    # propagate through the model
    img = img.to(device)
    outputs = model(img)
    # keep only predictions above the 0.8 confidence threshold
    # (comment fixed: it previously said 0.7 while the code used 0.8;
    # the last logit — DETR's "no object" class — is dropped before softmax-argmax)
    probas = outputs['pred_logits'].softmax(-1)[0, :, :-1]
    keep = probas.max(-1).values > 0.8
    # convert boxes from [0; 1] to image scales
    bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size)
    return probas[keep], bboxes_scaled
def get_fasterrcnn(num_classes):
    """Build a COCO-pretrained Faster R-CNN with a new box-predictor head.

    Args:
        num_classes: number of output classes (including the background class).

    Returns:
        A torchvision FasterRCNN model ready for fine-tuning.
    """
    # Start from the pretrained ResNet-50 FPN detector.
    model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
    # Replace the pretrained predictor with one sized to our class count.
    feature_dim = model.roi_heads.box_predictor.cls_score.in_features
    model.roi_heads.box_predictor = FastRCNNPredictor(feature_dim, num_classes)
    return model
def detect_frcnn(model, images):
    """Run a torchvision-style detector on one image; keep confident hits.

    Args:
        model: callable taking a list of image tensors and returning a list of
            dicts with 'scores' and 'labels' tensors.
        images: a single image tensor (wrapped in a one-element batch here).

    Returns:
        (scores, labels) for detections whose score exceeds 0.8.
    """
    predictions = model([images])[0]
    scores = predictions['scores']
    confident = scores > 0.8
    return scores[confident], predictions['labels'][confident]