-
Notifications
You must be signed in to change notification settings - Fork 260
Expand file tree
/
Copy pathsolver.py
More file actions
144 lines (120 loc) · 5.49 KB
/
solver.py
File metadata and controls
144 lines (120 loc) · 5.49 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
from math import ceil
from pathlib import Path
from lightning import LightningModule
from torchmetrics.detection import MeanAveragePrecision
from yolo.config.config import Config
from yolo.model.yolo import create_model
from yolo.tools.data_loader import create_dataloader
from yolo.utils.model_utils import prediction_to_sv
from yolo.tools.drawer import draw_bboxes
from yolo.tools.loss_functions import create_loss_function
from yolo.utils.bounding_box_utils import create_converter, to_metrics_format
from yolo.utils.model_utils import PostProcess, create_optimizer, create_scheduler
class BaseModel(LightningModule):
    """Lightning module that owns the detection network built from the config."""

    def __init__(self, cfg: Config):
        """Build the network via ``create_model``.

        Args:
            cfg: Project configuration; ``cfg.model``, ``cfg.dataset.class_num``
                and ``cfg.weight`` are read here.
        """
        super().__init__()
        self.model = create_model(
            cfg.model,
            class_num=cfg.dataset.class_num,
            weight_path=cfg.weight,
        )

    def forward(self, x):
        """Run ``x`` through the wrapped model and return its raw output."""
        outputs = self.model(x)
        return outputs
class ValidateModel(BaseModel):
    """Adds a validation loop with mean-average-precision tracking to ``BaseModel``."""

    def __init__(self, cfg: Config):
        """Select the validation config, then create the metric and the dataloader.

        Args:
            cfg: Project configuration. When ``cfg.task.task`` equals
                ``"validation"`` the task config itself is used as the
                validation config; otherwise ``cfg.task.validation`` is used.
        """
        super().__init__(cfg)
        self.cfg = cfg
        task_cfg = self.cfg.task
        self.validation_cfg = task_cfg if task_cfg.task == "validation" else task_cfg.validation

        self.metric = MeanAveragePrecision(
            iou_type="bbox",
            box_format="xyxy",
            backend="faster_coco_eval",
        )
        self.metric.warn_on_many_detections = False
        self.val_loader = create_dataloader(
            self.validation_cfg.data,
            self.cfg.dataset,
            self.validation_cfg.task,
        )
        # Validation forwards through ``self.ema``; it starts as an alias of the model.
        # NOTE(review): presumably swapped for averaged weights elsewhere - confirm.
        self.ema = self.model

    def setup(self, stage):
        """Create the prediction converter and the post-processing step."""
        model_cfg = self.cfg.model
        self.vec2box = create_converter(
            model_cfg.name,
            self.model,
            model_cfg.anchor,
            self.cfg.image_size,
            self.device,
        )
        self.post_process = PostProcess(self.vec2box, self.validation_cfg.nms)

    def val_dataloader(self):
        """Return the dataloader built in ``__init__``."""
        return self.val_loader

    def validation_step(self, batch, batch_idx):
        """Post-process predictions for one batch and feed them to the metric.

        Args:
            batch: Five-element batch; only the images (second item) and the
                targets (third item) are used here.
            batch_idx: Index of the batch (unused).

        Returns:
            The post-processed predictions for the batch.
        """
        _batch_size, images, targets, _rev_tensor, _img_paths = batch
        height, width = images.shape[2:]
        raw_outputs = self.ema(images)
        predicts = self.post_process(raw_outputs, image_size=[width, height])

        formatted_predicts = [to_metrics_format(predict) for predict in predicts]
        formatted_targets = [to_metrics_format(target) for target in targets]
        self.metric.update(formatted_predicts, formatted_targets)
        return predicts

    def on_validation_epoch_end(self):
        """Log the accumulated metric values, then reset the metric."""
        epoch_metrics = self.metric.compute()
        # The "classes" entry is dropped before the dict is passed to ``log_dict``.
        del epoch_metrics["classes"]
        self.log_dict(epoch_metrics, prog_bar=True, sync_dist=True, rank_zero_only=True)

        coco_summary = {
            "PyCOCO/AP @ .5:.95": epoch_metrics["map"],
            "PyCOCO/AP @ .5": epoch_metrics["map_50"],
        }
        self.log_dict(coco_summary, sync_dist=True, rank_zero_only=True)
        self.metric.reset()
class TrainModel(ValidateModel):
    """Adds the training loop, loss and optimizer setup to ``ValidateModel``."""

    def __init__(self, cfg: Config):
        """Create the training dataloader from ``cfg.task``."""
        super().__init__(cfg)
        self.cfg = cfg
        task_cfg = self.cfg.task
        self.train_loader = create_dataloader(task_cfg.data, self.cfg.dataset, task_cfg.task)

    def setup(self, stage):
        """Run the parent setup, then build the loss function from the converter."""
        super().setup(stage)
        self.loss_fn = create_loss_function(self.cfg, self.vec2box)

    def train_dataloader(self):
        """Return the dataloader built in ``__init__``."""
        return self.train_loader

    def on_train_epoch_start(self):
        """Notify the first optimizer of the new epoch and refresh the converter."""
        optimizer = self.trainer.optimizers[0]
        # Loader length divided across processes, rounded up.
        batches_per_process = ceil(len(self.train_loader) / self.trainer.world_size)
        optimizer.next_epoch(batches_per_process, self.current_epoch)
        self.vec2box.update(self.cfg.image_size)

    def training_step(self, batch, batch_idx):
        """Compute and log the loss for one batch.

        Args:
            batch: Sequence whose first three items are the batch size, the
                images and the targets; any further items are ignored.
            batch_idx: Index of the batch (unused).

        Returns:
            The loss multiplied by the batch size.
        """
        optimizer = self.trainer.optimizers[0]
        lr_dict = optimizer.next_batch()

        batch_size, images, targets, *_ = batch
        outputs = self(images)
        aux_predicts = self.vec2box(outputs["AUX"])
        main_predicts = self.vec2box(outputs["Main"])
        loss, loss_item = self.loss_fn(aux_predicts, main_predicts, targets)

        self.log_dict(
            loss_item,
            prog_bar=True,
            on_epoch=True,
            batch_size=batch_size,
            rank_zero_only=True,
        )
        self.log_dict(
            lr_dict,
            prog_bar=False,
            logger=True,
            on_epoch=False,
            rank_zero_only=True,
        )
        return loss * batch_size

    def configure_optimizers(self):
        """Build the optimizer and scheduler from ``cfg.task`` and return them as lists."""
        task_cfg = self.cfg.task
        optimizer = create_optimizer(self.model, task_cfg.optimizer)
        scheduler = create_scheduler(optimizer, task_cfg.scheduler)
        return [optimizer], [scheduler]
class InferenceModel(BaseModel):
    """Prediction-only module: runs the model, draws boxes and optionally saves frames."""

    def __init__(self, cfg: Config):
        """Create the prediction dataloader from ``cfg.task``."""
        super().__init__(cfg)
        self.cfg = cfg
        # TODO: Add FastModel
        task_cfg = cfg.task
        self.predict_loader = create_dataloader(task_cfg.data, cfg.dataset, task_cfg.task)

    def setup(self, stage):
        """Create the prediction converter and the post-processing step."""
        model_cfg = self.cfg.model
        self.vec2box = create_converter(
            model_cfg.name,
            self.model,
            model_cfg.anchor,
            self.cfg.image_size,
            self.device,
        )
        self.post_process = PostProcess(self.vec2box, self.cfg.task.nms)

    def predict_dataloader(self):
        """Return the dataloader built in ``__init__``."""
        return self.predict_loader

    def predict_step(self, batch, batch_idx):
        """Predict on one batch and draw the detections on the original frame.

        Args:
            batch: Three-element batch of images, ``rev_tensor`` and the
                original frame.
            batch_idx: Index of the batch; used in the saved file name.

        Returns:
            A ``(img, fps)`` tuple, where ``fps`` is ``None`` unless the
            loader has a truthy ``is_stream`` attribute.
        """
        images, rev_tensor, origin_frame = batch
        raw_outputs = self(images)
        predicts = self.post_process(raw_outputs, rev_tensor=rev_tensor)
        detections = prediction_to_sv(predicts)  # convert to sv format
        class_list = [str(label) for label in self.cfg.dataset.class_list]
        img = draw_bboxes(origin_frame, detections, idx2label=class_list)

        # NOTE(review): `_display_stream` is not defined in the visible part of
        # this class or of BaseModel - confirm it exists before using stream input.
        is_stream = getattr(self.predict_loader, "is_stream", None)
        fps = self._display_stream(img) if is_stream else None

        should_save = getattr(self.cfg.task, "save_predict", None)
        if should_save:
            self._save_image(img, batch_idx)
        return img, fps

    def _save_image(self, img, batch_idx):
        """Save ``img`` as ``frameNNN.png`` under the trainer's default root dir."""
        output_dir = Path(self.trainer.default_root_dir)
        save_image_path = output_dir / f"frame{batch_idx:03d}.png"
        img.save(save_image_path)
        print(f"💾 Saved visualize image at {save_image_path}")