diff --git a/src/live_stream_tracker.py b/src/live_stream_tracker.py index ffa5791..059efde 100644 --- a/src/live_stream_tracker.py +++ b/src/live_stream_tracker.py @@ -3,27 +3,10 @@ import argparse import datetime from collections.abc import Generator -from typing import TypedDict, List, Tuple import cv2 from ultralytics import YOLO -class BoundingBox(TypedDict): - x1: float - y1: float - x2: float - y2: float - -class DetectionData(TypedDict): - bbox: BoundingBox - confidence: float - class_label: int - -class DetectionResult(TypedDict): - ids: List[int] - datas: List[DetectionData] - frame: cv2.Mat - timestamp: float class LiveStreamDetector: """ @@ -47,13 +30,13 @@ def __init__( self.model = YOLO(self.model_path) self.cap = cv2.VideoCapture(self.stream_url) - def generate_detections(self) -> Generator[DetectionResult, None, None]: + def generate_detections(self) -> Generator[tuple, None, None]: """ Yields detection results, timestamp per frame from video capture. Yields: - Generator[DetectionResult]: Detection result including detection ids, - detection data, frame, and the current timestamp for each video frame. + Generator[Tuple]: Tuple of detection ids, detection data, frame, + and the current timestamp for each video frame. """ while self.cap.isOpened(): success, frame = self.cap.read() @@ -85,35 +68,10 @@ def generate_detections(self) -> Generator[DetectionResult, None, None]: else [] ) - # Convert datas_list to DetectionData format - detection_datas = [ - { - 'bbox': { - 'x1': data[0], - 'y1': data[1], - 'x2': data[2], - 'y2': data[3], - }, - 'confidence': data[4], - 'class_label': int(data[5]), - } - for data in datas_list - ] - # Yield the results - yield { - 'ids': ids_list, - 'datas': detection_datas, - 'frame': frame, - 'timestamp': timestamp, - } + yield ids_list, datas_list, frame, timestamp else: - yield { - 'ids': [], - 'datas': [], - 'frame': frame, - 'timestamp': timestamp, - } + yield [], [], frame, timestamp # Exit loop if 'q' is pressed if cv2.waitKey(1) & 0xFF == ord('q'): @@ -130,16 +88,16 @@ def run_detection(self): """ Runs the live stream detection and prints out detection results. """ - for result in self.generate_detections(): + for ids, datas, frame, timestamp in self.generate_detections(): print( 'Timestamp:', datetime.datetime.fromtimestamp( - result['timestamp'], + timestamp, ).strftime('%Y-%m-%d %H:%M:%S'), ) - print('IDs:', result['ids']) + print('IDs:', ids) print('Data (xyxy format):') - for data in result['datas']: - print(data) + print(datas) + if __name__ == '__main__': parser = argparse.ArgumentParser( @@ -168,5 +126,6 @@ def run_detection(self): # Release resources after detection is complete detector.release_resources() -# Example usage -# python live_stream_tracker.py --url https://cctv6.kctmc.nat.gov.tw/ea05668e/ +"""example +python live_stream_tracker.py --url https://cctv6.kctmc.nat.gov.tw/ea05668e/ +""" diff --git a/tests/live_stream_tracker_test.py b/tests/live_stream_tracker_test.py new file mode 100644 index 0000000..46b26e9 --- /dev/null +++ b/tests/live_stream_tracker_test.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +import datetime +import unittest +from unittest.mock import MagicMock +from unittest.mock import patch + +import numpy as np + +from src.live_stream_tracker import LiveStreamDetector + + +class TestLiveStreamDetector(unittest.TestCase): + def setUp(self): + self.stream_url = 'https://cctv6.kctmc.nat.gov.tw/ea05668e/' + self.model_path = 'models/pt/best_yolov8n.pt' + self.detector = LiveStreamDetector(self.stream_url, self.model_path) + + @patch('src.live_stream_tracker.YOLO') + @patch('src.live_stream_tracker.cv2.VideoCapture') + def test_initialisation(self, mock_video_capture, mock_yolo): + mock_cap_instance = MagicMock() + mock_video_capture.return_value = mock_cap_instance + + detector = LiveStreamDetector(self.stream_url, self.model_path) + + self.assertEqual(detector.stream_url, self.stream_url) + self.assertEqual(detector.model_path, self.model_path) + self.assertEqual(detector.model, mock_yolo.return_value) + self.assertEqual(detector.cap, mock_cap_instance) + + @patch('src.live_stream_tracker.YOLO') + @patch('src.live_stream_tracker.cv2.VideoCapture') + @patch('src.live_stream_tracker.datetime') + def test_generate_detections( + self, mock_datetime, mock_video_capture, mock_yolo, + ): + mock_cap_instance = MagicMock() + mock_video_capture.return_value = mock_cap_instance + mock_cap_instance.isOpened.side_effect = [True, True, False] + mock_cap_instance.read.side_effect = [ + ( + True, np.zeros( + (480, 640, 3), dtype=np.uint8, + ), + ), (True, np.zeros((480, 640, 3), dtype=np.uint8)), (False, None), + ] + + mock_yolo_instance = MagicMock() + mock_yolo.return_value = mock_yolo_instance + mock_results = MagicMock() + mock_results[0].boxes = MagicMock() + mock_results[0].boxes.id = MagicMock() + mock_results[0].boxes.data = MagicMock() + mock_results[0].boxes.id.numpy.return_value = [1, 2, 3] + mock_results[0].boxes.data.numpy.return_value = [ + [0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], + ] + mock_yolo_instance.track.return_value = mock_results + + mock_now = datetime.datetime(2023, 1, 1, 0, 0, 0) + mock_datetime.datetime.now.side_effect = [mock_now, mock_now] + + frame_generator = self.detector.generate_detections() + + ids, datas, frame, timestamp = next(frame_generator) + self.assertIsInstance(ids, list) + self.assertIsInstance(datas, list) + self.assertIsInstance(frame, np.ndarray) + self.assertIsInstance(timestamp, float) + + @patch('src.live_stream_tracker.cv2.VideoCapture') + @patch('src.live_stream_tracker.cv2.destroyAllWindows') + def test_release_resources( + self, mock_destroy_all_windows, mock_video_capture, + ): + mock_cap_instance = MagicMock() + mock_video_capture.return_value = mock_cap_instance + self.detector.cap = mock_cap_instance + + self.detector.release_resources() + + mock_cap_instance.release.assert_called_once() + mock_destroy_all_windows.assert_called_once() + + @patch('src.live_stream_tracker.LiveStreamDetector.generate_detections') + def test_run_detection(self, mock_generate_detections): + mock_generate_detections.return_value = iter( + [( + [1, 2, 3], [[0.1, 0.2, 0.3, 0.4]], np.zeros( + (480, 640, 3), dtype=np.uint8, + ), 1234567890.0, + )], + ) + + with patch('builtins.print') as mock_print: + self.detector.run_detection() + self.assertTrue( + any( + 'Timestamp:' in str(call) and '2009-02-14' in str(call) + for call in mock_print.call_args_list + ), + ) + self.assertTrue( + any( + 'IDs:' in str(call) + for call in mock_print.call_args_list + ), + ) + self.assertTrue( + any( + 'Data (xyxy format):' in str(call) + for call in mock_print.call_args_list + ), + ) + + +if __name__ == '__main__': + unittest.main()