-
Notifications
You must be signed in to change notification settings - Fork 6
/
detect_main.py
115 lines (98 loc) · 4.22 KB
/
detect_main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import cv2
import os
import time
import torch
import argparse
from nanodet.util import cfg, load_config, Logger
from nanodet.model.arch import build_model
from nanodet.util import load_model_weight
from nanodet.data.transform import Pipeline
# File extensions (including the leading dot) that get_image_list() accepts
# as images.
image_ext = ['.jpg', '.jpeg', '.webp', '.bmp', '.png']
# Video container extensions.
# NOTE(review): unlike image_ext these have no leading dot, so they would not
# match os.path.splitext() output as-is; the list is not referenced anywhere
# in the visible code.
video_ext = ['mp4', 'mov', 'avi', 'mkv']
# Usage examples. The bare string literals below are no-op expression
# statements used as section labels; an English translation precedes each.
# Object detection - image
'''目标检测-图片'''
# python detect_main.py image --config ./config/nanodet-m.yml --model model/nanodet_m.pth --path street.png
# Object detection - video file
'''目标检测-视频文件'''
# python detect_main.py video --config ./config/nanodet-m.yml --model model/nanodet_m.pth --path test.mp4
# Object detection - webcam (the camera is selected with --camid; --path is
# not read in this mode)
'''目标检测-摄像头'''
# python detect_main.py webcam --config ./config/nanodet-m.yml --model model/nanodet_m.pth --camid 0
def parse_args():
    """Parse the command-line arguments of the detection demo.

    The positional ``demo`` argument selects the input source. It is
    optional and falls back to ``'image'`` when omitted; any value other
    than ``image``, ``video`` or ``webcam`` is rejected by argparse instead
    of being silently ignored later on.

    Returns:
        argparse.Namespace: namespace with the attributes ``demo``,
        ``config``, ``model``, ``path`` and ``camid``.
    """
    parser = argparse.ArgumentParser()
    # nargs='?' is what makes the declared default effective: a plain
    # positional argument is always required, so its default is never used.
    parser.add_argument('demo', nargs='?', default='image',
                        choices=['image', 'video', 'webcam'],
                        help='demo type, eg. image, video and webcam')
    parser.add_argument('--config', help='model config file path')
    parser.add_argument('--model', help='model file path')
    parser.add_argument('--path', default='./demo', help='path to images or video')
    parser.add_argument('--camid', type=int, default=0, help='webcam demo camera id')
    return parser.parse_args()
class Predictor(object):
    """Bundle a detection model with its validation preprocessing pipeline.

    Args:
        cfg: loaded configuration object; ``cfg.model`` and ``cfg.data.val``
            are read here and in :meth:`inference`.
        model_path: path of the checkpoint file handed to ``torch.load``.
        logger: logger forwarded to ``load_model_weight``.
        device: torch device string the model and the input tensor are
            moved to.
    """

    def __init__(self, cfg, model_path, logger, device='cuda:0'):
        self.cfg = cfg
        self.device = device
        model = build_model(cfg.model)
        # Returning `storage` unchanged from map_location keeps the tensors
        # where torch.load deserialized them; the model is moved to `device`
        # explicitly below.
        # NOTE(review): torch.load unpickles the file, so only load
        # checkpoints that come from a trusted source.
        ckpt = torch.load(model_path, map_location=lambda storage, loc: storage)
        load_model_weight(model, ckpt, logger)
        self.model = model.to(device).eval()
        self.pipeline = Pipeline(cfg.data.val.pipeline, cfg.data.val.keep_ratio)

    def inference(self, img):
        """Run the model on a single image.

        Args:
            img: either the path of an image file (``str``) or an already
                decoded image array exposing ``.shape``.

        Returns:
            tuple: ``(meta, results)`` where ``meta`` is the dict produced by
            the preprocessing pipeline (with ``meta['img']`` replaced by the
            batched tensor on ``self.device``) and ``results`` is whatever
            ``self.model.inference`` returns.

        Raises:
            ValueError: if ``img`` is ``None`` or the given path could not be
                decoded by ``cv2.imread``.
        """
        img_info = {}
        if isinstance(img, str):
            img_path = img
            img_info['file_name'] = os.path.basename(img_path)
            img = cv2.imread(img_path)
            if img is None:
                # cv2.imread reports failure by returning None instead of
                # raising, which would otherwise surface as an opaque
                # AttributeError on `.shape` below.
                raise ValueError('Failed to read image: {}'.format(img_path))
        else:
            if img is None:
                raise ValueError('img must be an image path or an image array, got None')
            img_info['file_name'] = None

        height, width = img.shape[:2]
        img_info['height'] = height
        img_info['width'] = width
        meta = dict(img_info=img_info,
                    raw_img=img,
                    img=img)
        meta = self.pipeline(meta, self.cfg.data.val.input_size)
        # Move the last axis to the front and add a leading batch axis
        # (presumably HWC -> 1xCxHxW; confirm against the pipeline output).
        meta['img'] = torch.from_numpy(meta['img'].transpose(2, 0, 1)).unsqueeze(0).to(self.device)
        with torch.no_grad():
            results = self.model.inference(meta)
        return meta, results

    def visualize(self, dets, meta, class_names, score_thres, wait=0):
        """Display the detections on the raw image and print the time taken.

        Args:
            dets: detection results as returned by :meth:`inference`.
            meta: meta dict returned by :meth:`inference`; ``meta['raw_img']``
                is the image that gets drawn on.
            class_names: class names forwarded to ``show_result``.
            score_thres: score threshold forwarded to ``show_result``.
            wait: currently unused; kept so existing callers keep working.
        """
        time1 = time.time()
        self.model.head.show_result(meta['raw_img'], dets, class_names, score_thres=score_thres, show=True)
        print('viz time: {:.3f}s'.format(time.time()-time1))
def get_image_list(path):
    """Recursively collect the image files found under *path*.

    A file counts as an image when its extension, compared
    case-insensitively, is listed in the module-level ``image_ext`` (so
    ``photo.JPG`` is accepted just like ``photo.jpg``).

    Args:
        path: directory to walk with ``os.walk``.

    Returns:
        list: paths of the matching files, each joined with the directory it
        was found in. The order follows ``os.walk`` and is not sorted.
    """
    valid_exts = {ext.lower() for ext in image_ext}
    image_names = []
    for maindir, _subdirs, file_name_list in os.walk(path):
        for filename in file_name_list:
            if os.path.splitext(filename)[1].lower() in valid_exts:
                image_names.append(os.path.join(maindir, filename))
    return image_names
def main():
    """Run the detection demo selected on the command line.

    ``image`` mode runs inference on a single file or on every image found
    under a directory, waiting for a key press after each one. ``video`` and
    ``webcam`` modes read frames from ``--path`` or from camera ``--camid``
    respectively until the stream ends. Pressing Esc, ``q`` or ``Q`` stops
    the demo in every mode.
    """
    args = parse_args()
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

    load_config(cfg, args.config)
    logger = Logger(-1, use_tensorboard=False)
    # Inference runs on the CPU; pass device='cuda:0' to use a GPU instead.
    predictor = Predictor(cfg, args.model, logger, device='cpu')
    logger.log('Press "Esc", "q" or "Q" to exit.')

    exit_keys = (27, ord('q'), ord('Q'))  # 27 is the Esc key code

    if args.demo == 'image':
        if os.path.isdir(args.path):
            files = get_image_list(args.path)
        else:
            files = [args.path]
        files.sort()
        for image_name in files:
            meta, res = predictor.inference(image_name)
            predictor.visualize(res, meta, cfg.class_names, 0.35)
            if cv2.waitKey(0) in exit_keys:
                break
    elif args.demo == 'video' or args.demo == 'webcam':
        source = args.path if args.demo == 'video' else args.camid
        cap = cv2.VideoCapture(source)
        try:
            if not cap.isOpened():
                logger.log('Failed to open video source: {}'.format(source))
                return
            while True:
                ret_val, frame = cap.read()
                if not ret_val:
                    # No frame was grabbed: the video ended or the camera
                    # stopped delivering frames.
                    break
                meta, res = predictor.inference(frame)
                predictor.visualize(res, meta, cfg.class_names, 0.35)
                if cv2.waitKey(1) in exit_keys:
                    break
        finally:
            # Always hand the file/camera handle back, even on error.
            cap.release()

    cv2.destroyAllWindows()
# Script entry point: start the demo only when this file is executed
# directly, not when it is imported as a module.
if __name__ == '__main__':
    main()