From 2ea619e286b726f030881f6dbcb6e0acff6b178e Mon Sep 17 00:00:00 2001 From: yihong1120 Date: Mon, 8 Jul 2024 23:10:08 +0800 Subject: [PATCH] Add test file --- src/stream_capture.py | 47 ++++---- tests/stream_caputre_test.py | 200 +++++++++++++++++++++++++++++++++++ 2 files changed, 219 insertions(+), 28 deletions(-) create mode 100644 tests/stream_caputre_test.py diff --git a/src/stream_capture.py b/src/stream_capture.py index 258e8c5..1ecb466 100644 --- a/src/stream_capture.py +++ b/src/stream_capture.py @@ -5,19 +5,11 @@ import gc import time from collections.abc import Generator -from typing import TypedDict, Optional, Tuple import cv2 import speedtest import streamlink -class FrameData(TypedDict): - frame: cv2.Mat - timestamp: float - -class SpeedData(TypedDict): - download_speed: float - upload_speed: float class StreamCapture: """ @@ -32,7 +24,7 @@ def __init__(self, stream_url: str, capture_interval: int = 15): stream_url (str): The URL of the video stream. """ self.stream_url = stream_url - self.cap: Optional[cv2.VideoCapture] = None + self.cap: cv2.VideoCapture | None = None self.capture_interval = capture_interval def initialise_stream(self) -> None: @@ -55,12 +47,12 @@ def release_resources(self) -> None: self.cap = None gc.collect() - def capture_frames(self) -> Generator[FrameData, None, None]: + def capture_frames(self) -> Generator[tuple[cv2.Mat, float], None, None]: """ Captures frames from the stream and yields them with timestamps. Yields: - FrameData: The captured frame and the timestamp. + Tuple[cv2.Mat, float]: The captured frame and the timestamp. """ self.initialise_stream() last_process_time = datetime.datetime.now() - datetime.timedelta( @@ -86,7 +78,7 @@ def capture_frames(self) -> Generator[FrameData, None, None]: if elapsed_time >= self.capture_interval: last_process_time = current_time timestamp = current_time.timestamp() - yield {'frame': frame, 'timestamp': timestamp} + yield frame, timestamp # Clear memory del frame, timestamp @@ -96,31 +88,30 @@ def capture_frames(self) -> Generator[FrameData, None, None]: self.release_resources() - def check_internet_speed(self) -> SpeedData: + def check_internet_speed(self) -> tuple[float, float]: """ Checks internet speed using the Speedtest library. Returns: - SpeedData: The download and upload speeds in Mbps. + Tuple[float, float]: Download and upload speeds (Mbps). """ st = speedtest.Speedtest() st.get_best_server() download_speed = st.download() / 1_000_000 upload_speed = st.upload() / 1_000_000 - return {'download_speed': download_speed, 'upload_speed': upload_speed} + return download_speed, upload_speed - def select_quality_based_on_speed(self) -> Optional[str]: + def select_quality_based_on_speed(self) -> str | None: """ Selects stream quality based on internet speed. Returns: - Optional[str]: The URL of the selected stream quality. + str: The URL of the selected stream quality. Raises: - Exception: If no compatible stream quality is available. + Exception: If compatible stream quality is not available. """ - speed_data = self.check_internet_speed() - download_speed = speed_data['download_speed'] + download_speed, _ = self.check_internet_speed() try: streams = streamlink.streams(self.stream_url) available_qualities = list(streams.keys()) @@ -154,12 +145,12 @@ def select_quality_based_on_speed(self) -> Optional[str]: def capture_youtube_frames( self, - ) -> Generator[FrameData, None, None]: + ) -> Generator[tuple[cv2.Mat, float], None, None]: """ Captures frames from a YouTube stream. Yields: - FrameData: The captured frame and the timestamp. + Tuple[cv2.Mat, float]: The captured frame and the timestamp. """ stream_url = self.select_quality_based_on_speed() if not stream_url: @@ -183,7 +174,7 @@ def capture_youtube_frames( ).total_seconds() >= self.capture_interval: last_process_time = current_time timestamp = current_time.timestamp() - yield {'frame': frame, 'timestamp': timestamp} + yield frame, timestamp # 清理內存 del frame, timestamp @@ -195,12 +186,12 @@ def capture_youtube_frames( finally: self.release_resources() - def execute_capture(self) -> Generator[FrameData, None, None]: + def execute_capture(self) -> Generator[tuple[cv2.Mat, float], None, None]: """ Returns capture generator for stream type. Returns: - Generator[FrameData]: Yields frames and timestamps. + Generator[Tuple[cv2.Mat, float]]: Yields frames and timestamps. """ if ( 'youtube.com' in self.stream_url.lower() @@ -233,9 +224,9 @@ def update_capture_interval(self, new_interval: int) -> None: args = parser.parse_args() stream_capture = StreamCapture(args.url) - for frame_data in stream_capture.execute_capture(): + for frame, timestamp in stream_capture.execute_capture(): # Process the frame here - print(f"Frame at {frame_data['timestamp']} displayed") + print(f"Frame at {timestamp} displayed") # Release the frame resources - del frame_data['frame'] + del frame gc.collect() diff --git a/tests/stream_caputre_test.py b/tests/stream_caputre_test.py new file mode 100644 index 0000000..e44ee31 --- /dev/null +++ b/tests/stream_caputre_test.py @@ -0,0 +1,200 @@ +from __future__ import annotations + +import datetime +import gc +import time +import unittest +from unittest.mock import MagicMock +from unittest.mock import patch + +import cv2 + +from src.stream_capture import StreamCapture + + +class TestStreamCapture(unittest.TestCase): + def setUp(self): + self.stream_url = 'https://www.youtube.com/watch?v=mf-1VZ6ewlE' + self.capture_interval = 15 + self.stream_capture = StreamCapture( + self.stream_url, self.capture_interval, + ) + + @patch('src.stream_capture.cv2.VideoCapture') + def test_initialise_stream(self, mock_video_capture): + mock_cap_instance = MagicMock() + mock_video_capture.return_value = mock_cap_instance + + self.stream_capture.initialise_stream() + + mock_video_capture.assert_called_with(self.stream_url) + mock_cap_instance.set.assert_any_call(cv2.CAP_PROP_BUFFERSIZE, 1) + mock_cap_instance.set.assert_any_call( + cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*'H264'), + ) + self.assertTrue(mock_cap_instance.isOpened.called) + + @patch('src.stream_capture.cv2.VideoCapture') + def test_release_resources(self, mock_video_capture): + mock_cap_instance = MagicMock() + mock_video_capture.return_value = mock_cap_instance + self.stream_capture.cap = mock_cap_instance + + self.stream_capture.release_resources() + + mock_cap_instance.release.assert_called_once() + self.assertIsNone(self.stream_capture.cap) + gc.collect() + + @patch('src.stream_capture.cv2.VideoCapture') + @patch('src.stream_capture.datetime') + @patch('src.stream_capture.gc.collect') + def test_capture_frames( + self, mock_gc_collect, mock_datetime, mock_video_capture, + ): + mock_cap_instance = MagicMock() + mock_video_capture.return_value = mock_cap_instance + self.stream_capture.cap = mock_cap_instance + + mock_now = datetime.datetime(2023, 1, 1, 0, 0, 0) + mock_later = mock_now + \ + datetime.timedelta(seconds=self.capture_interval + 1) + mock_datetime.datetime.now.side_effect = [ + mock_now, + mock_now, + mock_later, + ] + + mock_cap_instance.read.side_effect = [ + (True, MagicMock()), (True, MagicMock()), + ] + + frame_generator = self.stream_capture.capture_frames() + + frame, timestamp = next(frame_generator) + self.assertIsNotNone(frame) + self.assertEqual(timestamp, mock_now.timestamp()) + + frame, timestamp = next(frame_generator) + self.assertIsNotNone(frame) + self.assertEqual(timestamp, mock_later.timestamp()) + + self.stream_capture.release_resources() + mock_gc_collect.assert_called() + + @patch('src.stream_capture.speedtest.Speedtest') + def test_check_internet_speed(self, mock_speedtest): + mock_st_instance = MagicMock() + mock_speedtest.return_value = mock_st_instance + + mock_st_instance.download.return_value = 100_000_000 + mock_st_instance.upload.return_value = 50_000_000 + + download_speed, upload_speed = ( + self.stream_capture.check_internet_speed() + ) + + self.assertEqual(download_speed, 100.0) + self.assertEqual(upload_speed, 50.0) + + @patch('src.stream_capture.streamlink.streams') + @patch('src.stream_capture.StreamCapture.check_internet_speed') + def test_select_quality_based_on_speed( + self, mock_check_internet_speed, mock_streams, + ): + mock_check_internet_speed.return_value = (15.0, 5.0) + mock_streams.return_value = { + 'best': MagicMock(url=self.stream_url), + '1080p': MagicMock(url=self.stream_url), + '720p': MagicMock(url=self.stream_url), + } + + selected_quality = self.stream_capture.select_quality_based_on_speed() + self.assertEqual(selected_quality, self.stream_url) + + @patch('src.stream_capture.streamlink.streams') + @patch('src.stream_capture.StreamCapture.check_internet_speed') + def test_select_quality_based_on_speed_no_quality( + self, mock_check_internet_speed, mock_streams, + ): + mock_check_internet_speed.return_value = (2.0, 1.0) + mock_streams.return_value = { + '720p': MagicMock(url=self.stream_url), + '480p': MagicMock(url=self.stream_url), + } + + selected_quality = self.stream_capture.select_quality_based_on_speed() + self.assertEqual(selected_quality, self.stream_url) + + @patch('src.stream_capture.cv2.VideoCapture') + @patch('src.stream_capture.streamlink.streams') + @patch('src.stream_capture.StreamCapture.check_internet_speed') + def test_capture_youtube_frames( + self, mock_check_internet_speed, mock_streams, mock_video_capture, + ): + mock_check_internet_speed.return_value = (15.0, 5.0) + mock_streams.return_value = { + 'best': MagicMock(url=self.stream_url), + '1080p': MagicMock(url=self.stream_url), + '720p': MagicMock(url=self.stream_url), + } + mock_cap_instance = MagicMock() + mock_video_capture.return_value = mock_cap_instance + self.stream_capture.cap = mock_cap_instance + + mock_cap_instance.read.side_effect = [ + (True, MagicMock()), (True, MagicMock()), + ] + + frame_generator = self.stream_capture.capture_youtube_frames() + + try: + frame, timestamp = next(frame_generator) + self.assertIsNotNone(frame) + self.assertIsInstance(timestamp, float) + except StopIteration: + self.fail('Generator stopped unexpectedly') + + mock_now_later = datetime.datetime.now( + ) + datetime.timedelta(seconds=self.capture_interval + 1) + with patch( + 'src.stream_capture.datetime.datetime.now', + return_value=mock_now_later, + ): + frame, timestamp = next(frame_generator) + self.assertIsNotNone(frame) + self.assertIsInstance(timestamp, float) + + self.stream_capture.release_resources() + + @patch('src.stream_capture.StreamCapture.capture_frames') + def test_execute_capture(self, mock_capture_frames): + expected_timestamp = time.time() # 使用当前时间戳 + mock_capture_frames.return_value = iter( + [(MagicMock(), expected_timestamp)], + ) + frames_generator = self.stream_capture.execute_capture() + + frame, timestamp = next(frames_generator) + self.assertIsNotNone(frame) + self.assertAlmostEqual( + timestamp, expected_timestamp, delta=40, + ) # 允许更大的差异 + + @patch('src.stream_capture.StreamCapture.capture_youtube_frames') + def test_execute_capture_youtube(self, mock_capture_youtube_frames): + self.stream_capture.stream_url = ( + 'https://www.youtube.com/watch?v=mf-1VZ6ewlE' + ) + mock_capture_youtube_frames.return_value = iter( + [(MagicMock(), 1234567890.0)], + ) + frames_generator = self.stream_capture.execute_capture() + + frame, timestamp = next(frames_generator) + self.assertIsNotNone(frame) + self.assertAlmostEqual(timestamp, 1234567890.0, delta=40) + + +if __name__ == '__main__': + unittest.main()