Skip to content

Commit

Permalink
support coco-wholebody visualization in pose_tracker python demo (#2450)
Browse files Browse the repository at this point in the history
* update

* update
  • Loading branch information
RunningLeon authored Sep 21, 2023
1 parent 062abd9 commit 01a88be
Showing 1 changed file with 98 additions and 23 deletions.
121 changes: 98 additions & 23 deletions demo/python/pose_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os

import cv2
import numpy as np
from mmdeploy_runtime import PoseTracker


Expand All @@ -18,39 +19,111 @@ def parse_args():
help='path of mmdeploy SDK model dumped by model converter')
parser.add_argument('video', help='video path or camera index')
parser.add_argument('--output_dir', help='output directory', default=None)
parser.add_argument(
'--skeleton',
default='coco',
choices=['coco', 'coco_wholebody'],
help='skeleton for keypoints')

args = parser.parse_args()
if args.video.isnumeric():
args.video = int(args.video)
return args


def visualize(frame, results, output_dir, frame_id, thr=0.5, resize=1280):
skeleton = [(15, 13), (13, 11), (16, 14), (14, 12), (11, 12), (5, 11),
(6, 12), (5, 6), (5, 7), (6, 8), (7, 9), (8, 10), (1, 2),
(0, 1), (0, 2), (1, 3), (2, 4), (3, 5), (4, 6)]
palette = [(255, 128, 0), (255, 153, 51), (255, 178, 102), (230, 230, 0),
(255, 153, 255), (153, 204, 255), (255, 102, 255),
(255, 51, 255), (102, 178, 255),
(51, 153, 255), (255, 153, 153), (255, 102, 102), (255, 51, 51),
(153, 255, 153), (102, 255, 102), (51, 255, 51), (0, 255, 0),
(0, 0, 255), (255, 0, 0), (255, 255, 255)]
link_color = [
0, 0, 0, 0, 7, 7, 7, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 16
]
point_color = [16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0]
VISUALIZATION_CFG = dict(
coco=dict(
skeleton=[(15, 13), (13, 11), (16, 14), (14, 12), (11, 12), (5, 11),
(6, 12), (5, 6), (5, 7), (6, 8), (7, 9), (8, 10), (1, 2),
(0, 1), (0, 2), (1, 3), (2, 4), (3, 5), (4, 6)],
palette=[(255, 128, 0), (255, 153, 51), (255, 178, 102), (230, 230, 0),
(255, 153, 255), (153, 204, 255), (255, 102, 255),
(255, 51, 255), (102, 178, 255), (51, 153, 255),
(255, 153, 153), (255, 102, 102), (255, 51, 51),
(153, 255, 153), (102, 255, 102), (51, 255, 51), (0, 255, 0),
(0, 0, 255), (255, 0, 0), (255, 255, 255)],
link_color=[
0, 0, 0, 0, 7, 7, 7, 9, 9, 9, 9, 9, 16, 16, 16, 16, 16, 16, 16
],
point_color=[16, 16, 16, 16, 16, 9, 9, 9, 9, 9, 9, 0, 0, 0, 0, 0, 0],
sigmas=[
0.026, 0.025, 0.025, 0.035, 0.035, 0.079, 0.079, 0.072, 0.072,
0.062, 0.062, 0.107, 0.107, 0.087, 0.087, 0.089, 0.089
]),
coco_wholebody=dict(
skeleton=[(15, 13), (13, 11), (16, 14), (14, 12), (11, 12), (5, 11),
(6, 12), (5, 6), (5, 7), (6, 8), (7, 9), (8, 10), (1, 2),
(0, 1), (0, 2), (1, 3), (2, 4), (3, 5), (4, 6), (15, 17),
(15, 18), (15, 19), (16, 20), (16, 21), (16, 22), (91, 92),
(92, 93), (93, 94), (94, 95), (91, 96), (96, 97), (97, 98),
(98, 99), (91, 100), (100, 101), (101, 102), (102, 103),
(91, 104), (104, 105), (105, 106), (106, 107), (91, 108),
(108, 109), (109, 110), (110, 111), (112, 113), (113, 114),
(114, 115), (115, 116), (112, 117), (117, 118), (118, 119),
(119, 120), (112, 121), (121, 122), (122, 123), (123, 124),
(112, 125), (125, 126), (126, 127), (127, 128), (112, 129),
(129, 130), (130, 131), (131, 132)],
palette=[(51, 153, 255), (0, 255, 0), (255, 128, 0), (255, 255, 255),
(255, 153, 255), (102, 178, 255), (255, 51, 51)],
link_color=[
1, 1, 2, 2, 0, 0, 0, 0, 1, 2, 1, 2, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1,
2, 2, 2, 2, 2, 2, 2, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 1, 1, 1,
1, 2, 2, 2, 2, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 1, 1, 1, 1
],
point_color=[
0, 0, 0, 0, 0, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 1, 2, 2, 2, 2, 2, 2,
2, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
3, 3, 3, 3, 2, 2, 2, 2, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 1, 1,
1, 1, 3, 2, 2, 2, 2, 4, 4, 4, 4, 5, 5, 5, 5, 6, 6, 6, 6, 1, 1, 1, 1
],
sigmas=[
0.026, 0.025, 0.025, 0.035, 0.035, 0.079, 0.079, 0.072, 0.072,
0.062, 0.062, 0.107, 0.107, 0.087, 0.087, 0.089, 0.089, 0.068,
0.066, 0.066, 0.092, 0.094, 0.094, 0.042, 0.043, 0.044, 0.043,
0.040, 0.035, 0.031, 0.025, 0.020, 0.023, 0.029, 0.032, 0.037,
0.038, 0.043, 0.041, 0.045, 0.013, 0.012, 0.011, 0.011, 0.012,
0.012, 0.011, 0.011, 0.013, 0.015, 0.009, 0.007, 0.007, 0.007,
0.012, 0.009, 0.008, 0.016, 0.010, 0.017, 0.011, 0.009, 0.011,
0.009, 0.007, 0.013, 0.008, 0.011, 0.012, 0.010, 0.034, 0.008,
0.008, 0.009, 0.008, 0.008, 0.007, 0.010, 0.008, 0.009, 0.009,
0.009, 0.007, 0.007, 0.008, 0.011, 0.008, 0.008, 0.008, 0.01,
0.008, 0.029, 0.022, 0.035, 0.037, 0.047, 0.026, 0.025, 0.024,
0.035, 0.018, 0.024, 0.022, 0.026, 0.017, 0.021, 0.021, 0.032,
0.02, 0.019, 0.022, 0.031, 0.029, 0.022, 0.035, 0.037, 0.047,
0.026, 0.025, 0.024, 0.035, 0.018, 0.024, 0.022, 0.026, 0.017,
0.021, 0.021, 0.032, 0.02, 0.019, 0.022, 0.031
]))


def visualize(frame,
results,
output_dir,
frame_id,
thr=0.5,
resize=1280,
skeleton_type='coco'):

skeleton = VISUALIZATION_CFG[skeleton_type]['skeleton']
palette = VISUALIZATION_CFG[skeleton_type]['palette']
link_color = VISUALIZATION_CFG[skeleton_type]['link_color']
point_color = VISUALIZATION_CFG[skeleton_type]['point_color']

scale = resize / max(frame.shape[0], frame.shape[1])
keypoints, bboxes, _ = results
scores = keypoints[..., 2]
keypoints = (keypoints[..., :2] * scale).astype(int)
bboxes *= scale
img = cv2.resize(frame, (0, 0), fx=scale, fy=scale)
for kpts, score, bbox in zip(keypoints, scores, bboxes):
show = [0] * len(kpts)
show = [1] * len(kpts)
for (u, v), color in zip(skeleton, link_color):
if score[u] > thr and score[v] > thr:
cv2.line(img, kpts[u], tuple(kpts[v]), palette[color], 1,
cv2.LINE_AA)
show[u] = show[v] = 1
else:
show[u] = show[v] = 0
for kpt, show, color in zip(kpts, show, point_color):
if show:
cv2.circle(img, kpt, 1, palette[color], 2, cv2.LINE_AA)
Expand All @@ -64,7 +137,7 @@ def visualize(frame, results, output_dir, frame_id, thr=0.5, resize=1280):

def main():
args = parse_args()

np.set_printoptions(precision=4, suppress=True)
video = cv2.VideoCapture(args.video)

tracker = PoseTracker(
Expand All @@ -73,12 +146,9 @@ def main():
device_name=args.device_name)

# optionally use OKS for keypoints similarity comparison
coco_sigmas = [
0.026, 0.025, 0.025, 0.035, 0.035, 0.079, 0.079, 0.072, 0.072, 0.062,
0.062, 0.107, 0.107, 0.087, 0.087, 0.089, 0.089
]
sigmas = VISUALIZATION_CFG[args.skeleton]['sigmas']
state = tracker.create_state(
det_interval=1, det_min_bbox_size=100, keypoint_sigmas=coco_sigmas)
det_interval=1, det_min_bbox_size=100, keypoint_sigmas=sigmas)

if args.output_dir:
os.makedirs(args.output_dir, exist_ok=True)
Expand All @@ -89,7 +159,12 @@ def main():
if not success:
break
results = tracker(state, frame, detect=-1)
if not visualize(frame, results, args.output_dir, frame_id):
if not visualize(
frame,
results,
args.output_dir,
frame_id,
skeleton_type=args.skeleton):
break
frame_id += 1

Expand Down

0 comments on commit 01a88be

Please sign in to comment.