Skip to content

Commit

Permalink
Add TypedDict
Browse files Browse the repository at this point in the history
  • Loading branch information
yihong1120 authored Jul 8, 2024
1 parent fb8600a commit ee2176f
Showing 1 changed file with 54 additions and 13 deletions.
67 changes: 54 additions & 13 deletions src/live_stream_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,27 @@
import argparse
import datetime
from collections.abc import Generator
from typing import TypedDict, List, Tuple

import cv2
from ultralytics import YOLO

class BoundingBox(TypedDict):
x1: float
y1: float
x2: float
y2: float

class DetectionData(TypedDict):
bbox: BoundingBox
confidence: float
class_label: int

class DetectionResult(TypedDict):
ids: List[int]
datas: List[DetectionData]
frame: cv2.Mat
timestamp: float

class LiveStreamDetector:
"""
Expand All @@ -30,13 +47,13 @@ def __init__(
self.model = YOLO(self.model_path)
self.cap = cv2.VideoCapture(self.stream_url)

def generate_detections(self) -> Generator[tuple, None, None]:
def generate_detections(self) -> Generator[DetectionResult, None, None]:
"""
Yields detection results, timestamp per frame from video capture.
Yields:
Generator[Tuple]: Tuple of detection ids, detection data, frame,
and the current timestamp for each video frame.
Generator[DetectionResult]: Detection result including detection ids,
detection data, frame, and the current timestamp for each video frame.
"""
while self.cap.isOpened():
success, frame = self.cap.read()
Expand Down Expand Up @@ -68,10 +85,35 @@ def generate_detections(self) -> Generator[tuple, None, None]:
else []
)

# Convert datas_list to DetectionData format
detection_datas = [
{
'bbox': {
'x1': data[0],
'y1': data[1],
'x2': data[2],
'y2': data[3],
},
'confidence': data[4],
'class_label': int(data[5]),
}
for data in datas_list
]

# Yield the results
yield ids_list, datas_list, frame, timestamp
yield {
'ids': ids_list,
'datas': detection_datas,
'frame': frame,
'timestamp': timestamp,
}
else:
yield [], [], frame, timestamp
yield {
'ids': [],
'datas': [],
'frame': frame,
'timestamp': timestamp,
}

# Exit loop if 'q' is pressed
if cv2.waitKey(1) & 0xFF == ord('q'):
Expand All @@ -88,16 +130,16 @@ def run_detection(self):
"""
Runs the live stream detection and prints out detection results.
"""
for ids, datas, frame, timestamp in self.generate_detections():
for result in self.generate_detections():
print(
'Timestamp:', datetime.datetime.fromtimestamp(
timestamp,
result['timestamp'],
).strftime('%Y-%m-%d %H:%M:%S'),
)
print('IDs:', ids)
print('IDs:', result['ids'])
print('Data (xyxy format):')
print(datas)

for data in result['datas']:
print(data)

if __name__ == '__main__':
parser = argparse.ArgumentParser(
Expand Down Expand Up @@ -126,6 +168,5 @@ def run_detection(self):
# Release resources after detection is complete
detector.release_resources()

"""example
python live_stream_tracker.py --url https://cctv6.kctmc.nat.gov.tw/ea05668e/
"""
# Example usage
# python live_stream_tracker.py --url https://cctv6.kctmc.nat.gov.tw/ea05668e/

0 comments on commit ee2176f

Please sign in to comment.