This repository was archived by the owner on Nov 13, 2025. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 37
Expand file tree
/
Copy pathdetect_target_ultralytics.py
More file actions
151 lines (123 loc) · 4.81 KB
/
detect_target_ultralytics.py
File metadata and controls
151 lines (123 loc) · 4.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
"""
Detects objects using the provided model.
"""
import time
import cv2
import torch
import ultralytics
from . import base_detect_target
from .. import image_and_time
from .. import detections_and_time
from ..common.modules.logger import logger
class DetectTargetUltralyticsConfig:
    """
    Settings bundle consumed by DetectTargetUltralytics.
    """

    # Device value that selects CPU inference.
    CPU_DEVICE = "cpu"

    def __init__(
        self,
        device: "str | int",
        model_path: str,
        override_full: bool,
    ) -> None:
        """
        Stores the supplied settings on the instance without validation.

        device: name of target device to run inference on (i.e. "cpu" or cuda device 0, 1, 2, 3).
        model_path: path to the YOLOv8 model.
        override_full: Force full precision floating point calculations.
        """
        self.device, self.model_path, self.override_full = (
            device,
            model_path,
            override_full,
        )
class DetectTargetUltralytics(base_detect_target.BaseDetectTarget):
    """
    Contains the YOLOv8 model for prediction.
    """

    def __init__(
        self,
        config: DetectTargetUltralyticsConfig,
        local_logger: logger.Logger,
        show_annotations: bool = False,
        save_name: str = "",
    ) -> None:
        """
        config: Settings providing `device`, `model_path` and `override_full`.
        local_logger: Logger used for warnings and per-frame detection messages.
        show_annotations: Display annotated images.
        save_name: filename prefix for logging detections and annotated images.
            An empty string disables saving.
        """
        self.__local_logger = local_logger

        # Any non-CPU device is treated as CUDA; fall back if CUDA is unavailable.
        self.__device = config.device
        if (
            self.__device != DetectTargetUltralyticsConfig.CPU_DEVICE
            and not torch.cuda.is_available()
        ):
            self.__local_logger.warning("CUDA not available. Falling back to CPU.")
            self.__device = DetectTargetUltralyticsConfig.CPU_DEVICE

        self.__model = ultralytics.YOLO(config.model_path)

        # Half precision is requested only off-CPU and only when the
        # configuration does not force full precision.
        self.__enable_half_precision = (
            self.__device != DetectTargetUltralyticsConfig.CPU_DEVICE
            and not config.override_full
        )

        # Index appended to saved image filenames.
        self.__counter = 0

        self.__show_annotations = show_annotations

        # The construction timestamp is embedded in the prefix.
        self.__filename_prefix = ""
        if save_name != "":
            self.__filename_prefix = save_name + "_" + str(int(time.time())) + "_"

    def run(
        self, data: image_and_time.ImageAndTime
    ) -> "tuple[bool, detections_and_time.DetectionsAndTime | None]":
        """
        Runs object detection on the provided image and returns the detections.

        data: Image with a timestamp.

        Return: Success and the detections. Returns (False, None) when the model
            yields no prediction, when no boxes are found, or when the detections
            container cannot be created.
        """
        image = data.image
        start_time = time.time()

        predictions = self.__model.predict(
            source=image,
            half=self.__enable_half_precision,
            device=self.__device,
            stream=False,
        )

        if len(predictions) == 0:
            return False, None

        prediction = predictions[0]

        # Processing object detection
        boxes = prediction.boxes
        if boxes.shape[0] == 0:
            return False, None

        # Render the annotated image only when it will actually be saved or shown.
        save_annotations = self.__filename_prefix != ""
        image_annotated = None
        if save_annotations or self.__show_annotations:
            image_annotated = prediction.plot(conf=True)

        # Copy bounding boxes, labels and confidences to CPU space once,
        # instead of converting one tensor element per loop iteration.
        objects_bounds = boxes.xyxy.detach().cpu().numpy()
        labels = boxes.cls.detach().cpu().numpy()
        confidences = boxes.conf.detach().cpu().numpy()

        result, detections = detections_and_time.DetectionsAndTime.create(data.timestamp)
        if not result:
            return False, None

        # Get Pylance to stop complaining
        assert detections is not None

        for bounds, label, confidence in zip(objects_bounds, labels, confidences):
            result, detection = detections_and_time.Detection.create(
                bounds, int(label), float(confidence)
            )
            if result:
                # Get Pylance to stop complaining
                assert detection is not None
                detections.append(detection)

        end_time = time.time()
        self.__local_logger.info(
            f"{time.time()}: Count: {self.__counter}. Target detection took {end_time - start_time} seconds. Objects detected: {detections}."
        )

        # Logging
        if save_annotations:
            filename = self.__filename_prefix + str(self.__counter)

            # Annotated image. imwrite reports failure through its return value,
            # so surface it instead of silently dropping the frame.
            if not cv2.imwrite(filename + ".png", image_annotated):
                self.__local_logger.warning(
                    f"Failed to save annotated image: {filename}.png"
                )

            # The counter advances once per save attempt.
            self.__counter += 1

        if self.__show_annotations:
            if image_annotated is not None:
                # Display the annotated image in a named window
                cv2.imshow("Annotated", image_annotated)
                cv2.waitKey(1)  # Short delay to process GUI events
            else:
                self.__local_logger.warning("Annotated image is invalid.")

        return True, detections