Skip to content

Commit 0438411

Browse files
committed
YOLOv4 -> YOLOv7-tiny_Head
1 parent 27cc575 commit 0438411

File tree

1 file changed

+260
-72
lines changed

1 file changed

+260
-72
lines changed

demo_video.py

Lines changed: 260 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,212 @@
22
import cv2
33
import time
44
import math
5+
import copy
56
import argparse
67
import onnxruntime
78
import numpy as np
89
from math import cos, sin
10+
from typing import Tuple, Optional, List
11+
12+
13+
class YOLOv7ONNX(object):
14+
def __init__(
15+
self,
16+
model_path: Optional[str] = 'yolov7_tiny_head_0.768_post_480x640.onnx',
17+
class_score_th: Optional[float] = 0.20,
18+
providers: Optional[List] = [
19+
(
20+
'TensorrtExecutionProvider', {
21+
'trt_engine_cache_enable': True,
22+
'trt_engine_cache_path': '.',
23+
'trt_fp16_enable': True,
24+
}
25+
),
26+
'CUDAExecutionProvider',
27+
'CPUExecutionProvider',
28+
],
29+
):
30+
"""YOLOv7ONNX
31+
Parameters
32+
----------
33+
model_path: Optional[str]
34+
ONNX file path for YOLOv7
35+
class_score_th: Optional[float]
36+
class_score_th: Optional[float]
37+
Score threshold. Default: 0.30
38+
providers: Optional[List]
39+
Name of onnx execution providers
40+
Default:
41+
[
42+
(
43+
'TensorrtExecutionProvider', {
44+
'trt_engine_cache_enable': True,
45+
'trt_engine_cache_path': '.',
46+
'trt_fp16_enable': True,
47+
}
48+
),
49+
'CUDAExecutionProvider',
50+
'CPUExecutionProvider',
51+
]
52+
"""
53+
# Threshold
54+
self.class_score_th = class_score_th
55+
56+
# Model loading
57+
session_option = onnxruntime.SessionOptions()
58+
session_option.log_severity_level = 3
59+
self.onnx_session = onnxruntime.InferenceSession(
60+
model_path,
61+
sess_options=session_option,
62+
providers=providers,
63+
)
64+
self.providers = self.onnx_session.get_providers()
65+
66+
self.input_shapes = [
67+
input.shape for input in self.onnx_session.get_inputs()
68+
]
69+
self.input_names = [
70+
input.name for input in self.onnx_session.get_inputs()
71+
]
72+
self.output_names = [
73+
output.name for output in self.onnx_session.get_outputs()
74+
]
75+
76+
77+
def __call__(
78+
self,
79+
image: np.ndarray,
80+
) -> Tuple[np.ndarray, np.ndarray]:
81+
"""YOLOv7ONNX
82+
Parameters
83+
----------
84+
image: np.ndarray
85+
Entire image
86+
Returns
87+
-------
88+
face_boxes: np.ndarray
89+
Predicted face boxes: [facecount, y1, x1, y2, x2]
90+
face_scores: np.ndarray
91+
Predicted face box scores: [facecount, score]
92+
"""
93+
temp_image = copy.deepcopy(image)
94+
95+
# PreProcess
96+
resized_image = self.__preprocess(
97+
temp_image,
98+
)
99+
100+
# Inference
101+
inferece_image = np.asarray([resized_image], dtype=np.float32)
102+
scores, boxes = self.onnx_session.run(
103+
self.output_names,
104+
{input_name: inferece_image for input_name in self.input_names},
105+
)
106+
107+
# PostProcess
108+
face_boxes, face_scores = self.__postprocess(
109+
image=temp_image,
110+
scores=scores,
111+
boxes=boxes,
112+
)
113+
114+
return face_boxes, face_scores
115+
116+
117+
def __preprocess(
118+
self,
119+
image: np.ndarray,
120+
swap: Optional[Tuple[int,int,int]] = (2, 0, 1),
121+
) -> np.ndarray:
122+
"""__preprocess
123+
Parameters
124+
----------
125+
image: np.ndarray
126+
Entire image
127+
swap: tuple
128+
HWC to CHW: (2,0,1)
129+
CHW to HWC: (1,2,0)
130+
HWC to HWC: (0,1,2)
131+
CHW to CHW: (0,1,2)
132+
Returns
133+
-------
134+
resized_image: np.ndarray
135+
Resized and normalized image.
136+
"""
137+
# Normalization + BGR->RGB
138+
resized_image = cv2.resize(
139+
image,
140+
(
141+
int(self.input_shapes[0][3]),
142+
int(self.input_shapes[0][2]),
143+
)
144+
)
145+
resized_image = np.divide(resized_image, 255.0)
146+
resized_image = resized_image[..., ::-1]
147+
resized_image = resized_image.transpose(swap)
148+
resized_image = np.ascontiguousarray(
149+
resized_image,
150+
dtype=np.float32,
151+
)
152+
return resized_image
153+
154+
155+
def __postprocess(
156+
self,
157+
image: np.ndarray,
158+
scores: np.ndarray,
159+
boxes: np.ndarray,
160+
) -> Tuple[np.ndarray, np.ndarray]:
161+
"""__postprocess
162+
Parameters
163+
----------
164+
image: np.ndarray
165+
Entire image.
166+
scores: np.ndarray
167+
float32[N, 1]
168+
boxes: np.ndarray
169+
int64[N, 6]
170+
Returns
171+
-------
172+
faceboxes: np.ndarray
173+
Predicted face boxes: [facecount, y1, x1, y2, x2]
174+
facescores: np.ndarray
175+
Predicted face box confs: [facecount, score]
176+
"""
177+
image_height = image.shape[0]
178+
image_width = image.shape[1]
179+
180+
"""
181+
Head Detector is
182+
N -> Number of boxes detected
183+
batchno -> always 0: BatchNo.0
184+
classid -> always 0: "Head"
185+
scores: float32[N,1],
186+
batchno_classid_y1x1y2x2: int64[N,6],
187+
"""
188+
scores = scores
189+
keep_idxs = scores[:, 0] > self.class_score_th
190+
scores_keep = scores[keep_idxs, :]
191+
boxes_keep = boxes[keep_idxs, :]
192+
faceboxes = []
193+
facescores = []
194+
195+
if len(boxes_keep) > 0:
196+
for box, score in zip(boxes_keep, scores_keep):
197+
x_min = max(int(box[3]), 0)
198+
y_min = max(int(box[2]), 0)
199+
x_max = min(int(box[5]), image_width)
200+
y_max = min(int(box[4]), image_height)
201+
202+
faceboxes.append(
203+
[x_min, y_min, x_max, y_max]
204+
)
205+
facescores.append(
206+
score
207+
)
208+
209+
return np.asarray(faceboxes), np.asarray(facescores)
210+
9211

10212

11213
def draw_axis(img, yaw, pitch, roll, tdx=None, tdy=None, size=100):
@@ -39,24 +241,10 @@ def draw_axis(img, yaw, pitch, roll, tdx=None, tdy=None, size=100):
39241

40242

41243
def main(args):
42-
# YOLOv4-Head
43-
yolov4_head = onnxruntime.InferenceSession(
44-
path_or_bytes=f'yolov4_headdetection_480x640_post.onnx',
45-
providers=[
46-
(
47-
'TensorrtExecutionProvider', {
48-
'trt_engine_cache_enable': True,
49-
'trt_engine_cache_path': '.',
50-
'trt_fp16_enable': True,
51-
}
52-
),
53-
'CUDAExecutionProvider',
54-
'CPUExecutionProvider',
55-
]
244+
# YOLOv7_tiny_Head
245+
yolov7_head = YOLOv7ONNX(
246+
class_score_th=0.20,
56247
)
57-
yolov4_head_input_name = yolov4_head.get_inputs()[0].name
58-
yolov4_head_H = yolov4_head.get_inputs()[0].shape[2]
59-
yolov4_head_W = yolov4_head.get_inputs()[0].shape[3]
60248

61249
# DMHead
62250
model_file_path = ''
@@ -98,44 +286,34 @@ def main(args):
98286
cv2.namedWindow(WINDOWS_NAME, cv2.WINDOW_NORMAL)
99287
cv2.resizeWindow(WINDOWS_NAME, cap_width, cap_height)
100288

101-
while True:
102-
start = time.time()
289+
cap_fps = cap.get(cv2.CAP_PROP_FPS)
290+
w = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
291+
h = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
292+
fourcc = cv2.VideoWriter_fourcc('m', 'p', '4', 'v')
293+
video_writer = cv2.VideoWriter(
294+
filename='output.mp4',
295+
fourcc=fourcc,
296+
fps=cap_fps,
297+
frameSize=(w, h),
298+
)
103299

300+
while True:
104301
ret, frame = cap.read()
105302
if not ret:
106-
continue
107-
108-
# ============================================================= YOLOv4
109-
# Resize
110-
resized_frame = cv2.resize(frame, (yolov4_head_W, yolov4_head_H))
111-
# BGR to RGB
112-
rgb = resized_frame[..., ::-1]
113-
# HWC -> CHW
114-
chw = rgb.transpose(2, 0, 1)
115-
# normalize to [0, 1] interval
116-
chw = np.asarray(chw / 255., dtype=np.float32)
117-
# hwc --> nhwc
118-
nchw = chw[np.newaxis, ...]
119-
# Inference YOLOv4
120-
heads = yolov4_head.run(
121-
None,
122-
input_feed = {yolov4_head_input_name: nchw}
123-
)[0]
124-
125-
canvas = resized_frame.copy()
303+
break
304+
305+
start = time.time()
306+
307+
# ============================================================= YOLOv7_tiny_Head
308+
heads, head_scores = yolov7_head(frame)
309+
310+
canvas = copy.deepcopy(frame)
126311
# ============================================================= DMHead
127312
croped_resized_frame = None
128-
scores = heads[:,4]
129-
keep_idxs = scores > 0.6
130-
heads = heads[keep_idxs, :]
131313

132314
if len(heads) > 0:
133315
dmhead_inputs = []
134316
dmhead_position = []
135-
heads[:, 0] = heads[:, 0] * cap_width
136-
heads[:, 1] = heads[:, 1] * cap_height
137-
heads[:, 2] = heads[:, 2] * cap_width
138-
heads[:, 3] = heads[:, 3] * cap_height
139317

140318
for head in heads:
141319
x_min = int(head[0])
@@ -145,11 +323,11 @@ def main(args):
145323

146324
# enlarge the bbox to include more background margin
147325
y_min = max(0, y_min - abs(y_min - y_max) / 10)
148-
y_max = min(resized_frame.shape[0], y_max + abs(y_min - y_max) / 10)
326+
y_max = min(frame.shape[0], y_max + abs(y_min - y_max) / 10)
149327
x_min = max(0, x_min - abs(x_min - x_max) / 5)
150-
x_max = min(resized_frame.shape[1], x_max + abs(x_min - x_max) / 5)
151-
x_max = min(x_max, resized_frame.shape[1])
152-
croped_frame = resized_frame[int(y_min):int(y_max), int(x_min):int(x_max)]
328+
x_max = min(frame.shape[1], x_max + abs(x_min - x_max) / 5)
329+
x_max = min(x_max, frame.shape[1])
330+
croped_frame = frame[int(y_min):int(y_max), int(x_min):int(x_max)]
153331

154332
# h,w -> 224,224
155333
croped_resized_frame = cv2.resize(croped_frame, (dmhead_W, dmhead_H))
@@ -227,35 +405,43 @@ def main(args):
227405
1
228406
)
229407

230-
time_txt = f'{(time.time()-start)*1000:.2f} ms'
231-
cv2.putText(
232-
canvas,
233-
time_txt,
234-
(20, 35),
235-
cv2.FONT_HERSHEY_SIMPLEX,
236-
1,
237-
(255, 255, 255),
238-
2,
239-
cv2.LINE_AA,
240-
)
241-
cv2.putText(
242-
canvas,
243-
time_txt,
244-
(20, 35),
245-
cv2.FONT_HERSHEY_SIMPLEX,
246-
1,
247-
(0, 0, 255),
248-
1,
249-
cv2.LINE_AA,
250-
)
408+
time_txt = f'{(time.time()-start)*1000:.2f} ms (inference+post-process)'
409+
cv2.putText(
410+
canvas,
411+
time_txt,
412+
(20, 35),
413+
cv2.FONT_HERSHEY_SIMPLEX,
414+
0.8,
415+
(255, 255, 255),
416+
2,
417+
cv2.LINE_AA,
418+
)
419+
cv2.putText(
420+
canvas,
421+
time_txt,
422+
(20, 35),
423+
cv2.FONT_HERSHEY_SIMPLEX,
424+
0.8,
425+
(0, 255, 0),
426+
1,
427+
cv2.LINE_AA,
428+
)
251429

252430
key = cv2.waitKey(1)
253431
if key == 27: # ESC
254432
break
255433

256434
cv2.imshow(WINDOWS_NAME, canvas)
435+
video_writer.write(canvas)
436+
257437
cv2.destroyAllWindows()
258438

439+
if video_writer:
440+
video_writer.release()
441+
442+
if cap:
443+
cap.release()
444+
259445
if __name__ == "__main__":
260446
parser = argparse.ArgumentParser()
261447
parser.add_argument(
@@ -278,7 +464,9 @@ def main(args):
278464
'mask',
279465
'nomask',
280466
],
281-
help='Select either a model that provides high accuracy when wearing a mask or a model that provides high accuracy when not wearing a mask.',
467+
help='\
468+
Select either a model that provides high accuracy when wearing \
469+
a mask or a model that provides high accuracy when not wearing a mask.',
282470
)
283471
args = parser.parse_args()
284472
main(args)

0 commit comments

Comments
 (0)