# forked from hackzoho/Data-Structures-and-Algorithms
# faster_cnn.py — 85 lines (69 loc) · 3.05 KB
import os.path

import cv2
import numpy as np
import requests
import torch
import torchvision
import torchvision.transforms as transforms
print("Faster R-CNN object detection")

# COCO dataset class names; index i corresponds to the model's label id i
# (includes the historical/unused COCO slots such as 'street sign', 'hat').
classes = [
    'background', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
    'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'street sign',
    'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
    'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'hat', 'backpack',
    'umbrella', 'shoe', 'eye glasses', 'handbag', 'tie', 'suitcase', 'frisbee',
    'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
    'skateboard', 'surfboard', 'tennis racket', 'bottle', 'plate', 'wine glass',
    'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
    'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake', 'chair',
    'couch', 'potted plant', 'bed', 'mirror', 'dining table', 'window', 'desk',
    'toilet', 'door', 'tv', 'laptop', 'mouse', 'remote', 'keyboard', 'cell phone',
    'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'blender', 'book',
    'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush', 'hair brush']

# Download the demo image once.
# NOTE: the raw.githubusercontent.com URL is required — the original
# github.com/.../blob/... URL returns an HTML page, not the PNG bytes,
# so cv2.imread() would fail on the saved file.
image_file = 'source_2.png'
if not os.path.isfile(image_file):
    url = ("https://raw.githubusercontent.com/ivan-vasilev/"
           "advanced-deep-learning-with-python/master/"
           "chapter04-detection-segmentation/source_2.png")
    r = requests.get(url, timeout=60)
    r.raise_for_status()  # fail loudly instead of writing an error page to disk
    with open(image_file, 'wb') as f:
        f.write(r.content)

# load the pretrained Faster R-CNN model and switch to inference mode
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
model.eval()

# read the image file; imread returns None (no exception) on failure
img = cv2.imread(image_file)
if img is None:
    raise FileNotFoundError(f"could not read image: {image_file}")

# transform the input to a [0, 1] float tensor (C, H, W)
# NOTE(review): cv2 loads BGR while torchvision models are trained on RGB;
# this follows the original code and passes the image through as-is.
transform = transforms.Compose([transforms.ToPILImage(), transforms.ToTensor()])
nn_input = transform(img)

# inference only — no_grad avoids building the autograd graph
with torch.no_grad():
    output = model([nn_input])

# one random color per class so boxes of the same class match
colors = np.random.uniform(0, 255, size=(len(classes), 3))

# iterate over all detected boxes, keeping confident ones only
for box, box_class, score in zip(output[0]['boxes'].numpy(),
                                 output[0]['labels'].numpy(),
                                 output[0]['scores'].numpy()):
    if score > 0.5:
        # cv2 drawing functions require integer pixel coordinates
        pt1 = (int(box[0]), int(box[1]))
        pt2 = (int(box[2]), int(box[3]))
        # cv2 expects a plain scalar tuple, not a numpy array
        color = tuple(int(c) for c in colors[box_class])
        class_name = classes[box_class]
        # draw the bounding box
        cv2.rectangle(img=img,
                      pt1=pt1,
                      pt2=pt2,
                      color=color,
                      thickness=2)
        # display the box class label at the top-left corner of the box
        cv2.putText(img=img,
                    text=class_name,
                    org=pt1,
                    fontFace=cv2.FONT_HERSHEY_SIMPLEX,
                    fontScale=1,
                    color=color,
                    thickness=2)

cv2.imshow("Object detection", img)
cv2.waitKey()
cv2.destroyAllWindows()