Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Keep track of detection index #299

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 51 additions & 8 deletions yolox/tracker/byte_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,14 @@

class STrack(BaseTrack):
shared_kalman = KalmanFilter()
def __init__(self, tlwh, score):
def __init__(self, tlwh, score, det_idx=None):
"""
Initialize a tracklet
Args:
tlwh: (np.ndarray) bbox in format x1,y1,x2,y2
score: (float) bbox detection score
det_idx: (int) (Optional) corresponding index in object detection list
"""

# wait activate
self._tlwh = np.asarray(tlwh, dtype=np.float)
Expand All @@ -23,6 +30,8 @@ def __init__(self, tlwh, score):
self.score = score
self.tracklet_len = 0

self.det_idx = det_idx

def predict(self):
mean_state = self.mean.copy()
if self.state != TrackState.Tracked:
Expand Down Expand Up @@ -67,7 +76,8 @@ def re_activate(self, new_track, frame_id, new_id=False):
if new_id:
self.track_id = self.next_id()
self.score = new_track.score

self.det_idx = new_track.det_idx

def update(self, new_track, frame_id):
"""
Update a matched track
Expand All @@ -87,6 +97,8 @@ def update(self, new_track, frame_id):

self.score = new_track.score

self.det_idx = new_track.det_idx

@property
# @jit(nopython=True)
def tlwh(self):
Expand Down Expand Up @@ -156,17 +168,33 @@ def __init__(self, args, frame_rate=30):
self.max_time_lost = self.buffer_size
self.kalman_filter = KalmanFilter()

def update(self, output_results, img_info, img_size):
def update(self, output_results, img_info, img_size, track_det_idx=False):
"""
Update tracker with detection results
Args:
output_results: detection results + Scores + det_idx (optional)
img_info: original image information
img_size: inference scaled image size
track_det_idx: whether to track det_idx (index corresponding to the detection results)
"""
self.frame_id += 1
activated_starcks = []
refind_stracks = []
lost_stracks = []
removed_stracks = []

if track_det_idx:
if output_results.shape[1] == 6:
scores = output_results[:, 4]
bboxes = output_results[:, :4] # x1y1x2y2
_det_idxs = output_results[:, 5]
else:
raise ValueError('output_results shape error')

if output_results.shape[1] == 5:
scores = output_results[:, 4]
bboxes = output_results[:, :4]
else:
elif not track_det_idx:
output_results = output_results.cpu().numpy()
scores = output_results[:, 4] * output_results[:, 5]
bboxes = output_results[:, :4] # x1y1x2y2
Expand All @@ -183,11 +211,19 @@ def update(self, output_results, img_info, img_size):
dets = bboxes[remain_inds]
scores_keep = scores[remain_inds]
scores_second = scores[inds_second]
if track_det_idx:
det_idxs = _det_idxs[remain_inds]
det_idxs_second = _det_idxs[inds_second]

if len(dets) > 0:
'''Detections'''
detections = [STrack(STrack.tlbr_to_tlwh(tlbr), s) for
(tlbr, s) in zip(dets, scores_keep)]
detections = []
for i, (tlbr, s) in enumerate(zip(dets, scores_keep)):
di = det_idxs[i]
if track_det_idx:
detections.append(STrack(STrack.tlbr_to_tlwh(tlbr), s, di))
else:
detections.append(STrack(STrack.tlbr_to_tlwh(tlbr), s))
else:
detections = []

Expand Down Expand Up @@ -223,8 +259,15 @@ def update(self, output_results, img_info, img_size):
# association the untrack to the low score detections
if len(dets_second) > 0:
'''Detections'''
detections_second = [STrack(STrack.tlbr_to_tlwh(tlbr), s) for
(tlbr, s) in zip(dets_second, scores_second)]
#detections_second = [STrack(STrack.tlbr_to_tlwh(tlbr), s, di) for
# (tlbr, s, di) in zip(dets_second, scores_second, det_idxs_second)]
detections_second = []
for i, (tlbr, s) in enumerate(zip(dets_second, scores_second)):
if track_det_idx:
di = det_idxs_second[i]
detections_second.append(STrack(STrack.tlbr_to_tlwh(tlbr), s, di))
else:
detections_second.append(STrack(STrack.tlbr_to_tlwh(tlbr), s))
else:
detections_second = []
r_tracked_stracks = [strack_pool[i] for i in u_track if strack_pool[i].state == TrackState.Tracked]
Expand Down