Skip to content

Commit

Permalink
Add test file
Browse files Browse the repository at this point in the history
  • Loading branch information
yihong1120 committed Jul 8, 2024
1 parent 81c5fe4 commit 2a33c15
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 54 deletions.
67 changes: 13 additions & 54 deletions src/live_stream_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,10 @@
import argparse
import datetime
from collections.abc import Generator
from typing import TypedDict, List, Tuple

import cv2
from ultralytics import YOLO

class BoundingBox(TypedDict):
x1: float
y1: float
x2: float
y2: float

class DetectionData(TypedDict):
bbox: BoundingBox
confidence: float
class_label: int

class DetectionResult(TypedDict):
ids: List[int]
datas: List[DetectionData]
frame: cv2.Mat
timestamp: float

class LiveStreamDetector:
"""
Expand All @@ -47,13 +30,13 @@ def __init__(
self.model = YOLO(self.model_path)
self.cap = cv2.VideoCapture(self.stream_url)

def generate_detections(self) -> Generator[DetectionResult, None, None]:
def generate_detections(self) -> Generator[tuple, None, None]:
"""
Yields detection results, timestamp per frame from video capture.
Yields:
Generator[DetectionResult]: Detection result including detection ids,
detection data, frame, and the current timestamp for each video frame.
Generator[Tuple]: Tuple of detection ids, detection data, frame,
and the current timestamp for each video frame.
"""
while self.cap.isOpened():
success, frame = self.cap.read()
Expand Down Expand Up @@ -85,35 +68,10 @@ def generate_detections(self) -> Generator[DetectionResult, None, None]:
else []
)

# Convert datas_list to DetectionData format
detection_datas = [
{
'bbox': {
'x1': data[0],
'y1': data[1],
'x2': data[2],
'y2': data[3],
},
'confidence': data[4],
'class_label': int(data[5]),
}
for data in datas_list
]

# Yield the results
yield {
'ids': ids_list,
'datas': detection_datas,
'frame': frame,
'timestamp': timestamp,
}
yield ids_list, datas_list, frame, timestamp
else:
yield {
'ids': [],
'datas': [],
'frame': frame,
'timestamp': timestamp,
}
yield [], [], frame, timestamp

# Exit loop if 'q' is pressed
if cv2.waitKey(1) & 0xFF == ord('q'):
Expand All @@ -130,16 +88,16 @@ def run_detection(self):
"""
Runs the live stream detection and prints out detection results.
"""
for result in self.generate_detections():
for ids, datas, frame, timestamp in self.generate_detections():
print(
'Timestamp:', datetime.datetime.fromtimestamp(
result['timestamp'],
timestamp,
).strftime('%Y-%m-%d %H:%M:%S'),
)
print('IDs:', result['ids'])
print('IDs:', ids)
print('Data (xyxy format):')
for data in result['datas']:
print(data)
print(datas)


if __name__ == '__main__':
parser = argparse.ArgumentParser(
Expand Down Expand Up @@ -168,5 +126,6 @@ def run_detection(self):
# Release resources after detection is complete
detector.release_resources()

# Example usage
# python live_stream_tracker.py --url https://cctv6.kctmc.nat.gov.tw/ea05668e/
"""example
python live_stream_tracker.py --url https://cctv6.kctmc.nat.gov.tw/ea05668e/
"""
119 changes: 119 additions & 0 deletions tests/live_stream_tracker_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
from __future__ import annotations

import datetime
import unittest
from unittest.mock import MagicMock
from unittest.mock import patch

import numpy as np

from src.live_stream_tracker import LiveStreamDetector


class TestLiveStreamDetector(unittest.TestCase):
def setUp(self):
self.stream_url = 'https://cctv6.kctmc.nat.gov.tw/ea05668e/'
self.model_path = 'models/pt/best_yolov8n.pt'
self.detector = LiveStreamDetector(self.stream_url, self.model_path)

@patch('src.live_stream_tracker.YOLO')
@patch('src.live_stream_tracker.cv2.VideoCapture')
def test_initialisation(self, mock_video_capture, mock_yolo):
mock_cap_instance = MagicMock()
mock_video_capture.return_value = mock_cap_instance

detector = LiveStreamDetector(self.stream_url, self.model_path)

self.assertEqual(detector.stream_url, self.stream_url)
self.assertEqual(detector.model_path, self.model_path)
self.assertEqual(detector.model, mock_yolo.return_value)
self.assertEqual(detector.cap, mock_cap_instance)

@patch('src.live_stream_tracker.YOLO')
@patch('src.live_stream_tracker.cv2.VideoCapture')
@patch('src.live_stream_tracker.datetime')
def test_generate_detections(
self, mock_datetime, mock_video_capture, mock_yolo,
):
mock_cap_instance = MagicMock()
mock_video_capture.return_value = mock_cap_instance
mock_cap_instance.isOpened.side_effect = [True, True, False]
mock_cap_instance.read.side_effect = [
(
True, np.zeros(
(480, 640, 3), dtype=np.uint8,
),
), (True, np.zeros((480, 640, 3), dtype=np.uint8)), (False, None),
]

mock_yolo_instance = MagicMock()
mock_yolo.return_value = mock_yolo_instance
mock_results = MagicMock()
mock_results[0].boxes = MagicMock()
mock_results[0].boxes.id = MagicMock()
mock_results[0].boxes.data = MagicMock()
mock_results[0].boxes.id.numpy.return_value = [1, 2, 3]
mock_results[0].boxes.data.numpy.return_value = [
[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8],
]
mock_yolo_instance.track.return_value = mock_results

mock_now = datetime.datetime(2023, 1, 1, 0, 0, 0)
mock_datetime.datetime.now.side_effect = [mock_now, mock_now]

frame_generator = self.detector.generate_detections()

ids, datas, frame, timestamp = next(frame_generator)
self.assertIsInstance(ids, list)
self.assertIsInstance(datas, list)
self.assertIsInstance(frame, np.ndarray)
self.assertIsInstance(timestamp, float)

@patch('src.live_stream_tracker.cv2.VideoCapture')
@patch('src.live_stream_tracker.cv2.destroyAllWindows')
def test_release_resources(
self, mock_destroy_all_windows, mock_video_capture,
):
mock_cap_instance = MagicMock()
mock_video_capture.return_value = mock_cap_instance
self.detector.cap = mock_cap_instance

self.detector.release_resources()

mock_cap_instance.release.assert_called_once()
mock_destroy_all_windows.assert_called_once()

@patch('src.live_stream_tracker.LiveStreamDetector.generate_detections')
def test_run_detection(self, mock_generate_detections):
mock_generate_detections.return_value = iter(
[(
[1, 2, 3], [[0.1, 0.2, 0.3, 0.4]], np.zeros(
(480, 640, 3), dtype=np.uint8,
), 1234567890.0,
)],
)

with patch('builtins.print') as mock_print:
self.detector.run_detection()
self.assertTrue(
any(
'Timestamp:' in str(call) and '2009-02-14' in str(call)
for call in mock_print.call_args_list
),
)
self.assertTrue(
any(
'IDs:' in str(call)
for call in mock_print.call_args_list
),
)
self.assertTrue(
any(
'Data (xyxy format):' in str(call)
for call in mock_print.call_args_list
),
)


if __name__ == '__main__':
unittest.main()

0 comments on commit 2a33c15

Please sign in to comment.