Skip to content

Commit

Permalink
Add test file
Browse files Browse the repository at this point in the history
  • Loading branch information
yihong1120 committed Jul 8, 2024
1 parent d936195 commit 2ea619e
Show file tree
Hide file tree
Showing 2 changed files with 219 additions and 28 deletions.
47 changes: 19 additions & 28 deletions src/stream_capture.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,11 @@
import gc
import time
from collections.abc import Generator
from typing import TypedDict, Optional, Tuple

import cv2
import speedtest
import streamlink

class FrameData(TypedDict):
frame: cv2.Mat
timestamp: float

class SpeedData(TypedDict):
download_speed: float
upload_speed: float

class StreamCapture:
"""
Expand All @@ -32,7 +24,7 @@ def __init__(self, stream_url: str, capture_interval: int = 15):
stream_url (str): The URL of the video stream.
"""
self.stream_url = stream_url
self.cap: Optional[cv2.VideoCapture] = None
self.cap: cv2.VideoCapture | None = None
self.capture_interval = capture_interval

def initialise_stream(self) -> None:
Expand All @@ -55,12 +47,12 @@ def release_resources(self) -> None:
self.cap = None
gc.collect()

def capture_frames(self) -> Generator[FrameData, None, None]:
def capture_frames(self) -> Generator[tuple[cv2.Mat, float], None, None]:
"""
Captures frames from the stream and yields them with timestamps.
Yields:
FrameData: The captured frame and the timestamp.
Tuple[cv2.Mat, float]: The captured frame and the timestamp.
"""
self.initialise_stream()
last_process_time = datetime.datetime.now() - datetime.timedelta(
Expand All @@ -86,7 +78,7 @@ def capture_frames(self) -> Generator[FrameData, None, None]:
if elapsed_time >= self.capture_interval:
last_process_time = current_time
timestamp = current_time.timestamp()
yield {'frame': frame, 'timestamp': timestamp}
yield frame, timestamp

# Clear memory
del frame, timestamp
Expand All @@ -96,31 +88,30 @@ def capture_frames(self) -> Generator[FrameData, None, None]:

self.release_resources()

def check_internet_speed(self) -> SpeedData:
def check_internet_speed(self) -> tuple[float, float]:
"""
Checks internet speed using the Speedtest library.
Returns:
SpeedData: The download and upload speeds in Mbps.
Tuple[float, float]: Download and upload speeds (Mbps).
"""
st = speedtest.Speedtest()
st.get_best_server()
download_speed = st.download() / 1_000_000
upload_speed = st.upload() / 1_000_000
return {'download_speed': download_speed, 'upload_speed': upload_speed}
return download_speed, upload_speed

def select_quality_based_on_speed(self) -> Optional[str]:
def select_quality_based_on_speed(self) -> str | None:
"""
Selects stream quality based on internet speed.
Returns:
Optional[str]: The URL of the selected stream quality.
str: The URL of the selected stream quality.
Raises:
Exception: If no compatible stream quality is available.
Exception: If compatible stream quality is not available.
"""
speed_data = self.check_internet_speed()
download_speed = speed_data['download_speed']
download_speed, _ = self.check_internet_speed()
try:
streams = streamlink.streams(self.stream_url)
available_qualities = list(streams.keys())
Expand Down Expand Up @@ -154,12 +145,12 @@ def select_quality_based_on_speed(self) -> Optional[str]:

def capture_youtube_frames(
self,
) -> Generator[FrameData, None, None]:
) -> Generator[tuple[cv2.Mat, float], None, None]:
"""
Captures frames from a YouTube stream.
Yields:
FrameData: The captured frame and the timestamp.
Tuple[cv2.Mat, float]: The captured frame and the timestamp.
"""
stream_url = self.select_quality_based_on_speed()
if not stream_url:
Expand All @@ -183,7 +174,7 @@ def capture_youtube_frames(
).total_seconds() >= self.capture_interval:
last_process_time = current_time
timestamp = current_time.timestamp()
yield {'frame': frame, 'timestamp': timestamp}
yield frame, timestamp

# 清理內存
del frame, timestamp
Expand All @@ -195,12 +186,12 @@ def capture_youtube_frames(
finally:
self.release_resources()

def execute_capture(self) -> Generator[FrameData, None, None]:
def execute_capture(self) -> Generator[tuple[cv2.Mat, float], None, None]:
"""
Returns capture generator for stream type.
Returns:
Generator[FrameData]: Yields frames and timestamps.
Generator[Tuple[cv2.Mat, float]]: Yields frames and timestamps.
"""
if (
'youtube.com' in self.stream_url.lower()
Expand Down Expand Up @@ -233,9 +224,9 @@ def update_capture_interval(self, new_interval: int) -> None:
args = parser.parse_args()

stream_capture = StreamCapture(args.url)
for frame_data in stream_capture.execute_capture():
for frame, timestamp in stream_capture.execute_capture():
# Process the frame here
print(f"Frame at {frame_data['timestamp']} displayed")
print(f"Frame at {timestamp} displayed")
# Release the frame resources
del frame_data['frame']
del frame
gc.collect()
200 changes: 200 additions & 0 deletions tests/stream_caputre_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
from __future__ import annotations

import datetime
import gc
import time
import unittest
from unittest.mock import MagicMock
from unittest.mock import patch

import cv2

from src.stream_capture import StreamCapture


class TestStreamCapture(unittest.TestCase):
def setUp(self):
self.stream_url = 'https://www.youtube.com/watch?v=mf-1VZ6ewlE'
self.capture_interval = 15
self.stream_capture = StreamCapture(
self.stream_url, self.capture_interval,
)

@patch('src.stream_capture.cv2.VideoCapture')
def test_initialise_stream(self, mock_video_capture):
mock_cap_instance = MagicMock()
mock_video_capture.return_value = mock_cap_instance

self.stream_capture.initialise_stream()

mock_video_capture.assert_called_with(self.stream_url)
mock_cap_instance.set.assert_any_call(cv2.CAP_PROP_BUFFERSIZE, 1)
mock_cap_instance.set.assert_any_call(
cv2.CAP_PROP_FOURCC, cv2.VideoWriter_fourcc(*'H264'),
)
self.assertTrue(mock_cap_instance.isOpened.called)

@patch('src.stream_capture.cv2.VideoCapture')
def test_release_resources(self, mock_video_capture):
mock_cap_instance = MagicMock()
mock_video_capture.return_value = mock_cap_instance
self.stream_capture.cap = mock_cap_instance

self.stream_capture.release_resources()

mock_cap_instance.release.assert_called_once()
self.assertIsNone(self.stream_capture.cap)
gc.collect()

@patch('src.stream_capture.cv2.VideoCapture')
@patch('src.stream_capture.datetime')
@patch('src.stream_capture.gc.collect')
def test_capture_frames(
self, mock_gc_collect, mock_datetime, mock_video_capture,
):
mock_cap_instance = MagicMock()
mock_video_capture.return_value = mock_cap_instance
self.stream_capture.cap = mock_cap_instance

mock_now = datetime.datetime(2023, 1, 1, 0, 0, 0)
mock_later = mock_now + \
datetime.timedelta(seconds=self.capture_interval + 1)
mock_datetime.datetime.now.side_effect = [
mock_now,
mock_now,
mock_later,
]

mock_cap_instance.read.side_effect = [
(True, MagicMock()), (True, MagicMock()),
]

frame_generator = self.stream_capture.capture_frames()

frame, timestamp = next(frame_generator)
self.assertIsNotNone(frame)
self.assertEqual(timestamp, mock_now.timestamp())

frame, timestamp = next(frame_generator)
self.assertIsNotNone(frame)
self.assertEqual(timestamp, mock_later.timestamp())

self.stream_capture.release_resources()
mock_gc_collect.assert_called()

@patch('src.stream_capture.speedtest.Speedtest')
def test_check_internet_speed(self, mock_speedtest):
mock_st_instance = MagicMock()
mock_speedtest.return_value = mock_st_instance

mock_st_instance.download.return_value = 100_000_000
mock_st_instance.upload.return_value = 50_000_000

download_speed, upload_speed = (
self.stream_capture.check_internet_speed()
)

self.assertEqual(download_speed, 100.0)
self.assertEqual(upload_speed, 50.0)

@patch('src.stream_capture.streamlink.streams')
@patch('src.stream_capture.StreamCapture.check_internet_speed')
def test_select_quality_based_on_speed(
self, mock_check_internet_speed, mock_streams,
):
mock_check_internet_speed.return_value = (15.0, 5.0)
mock_streams.return_value = {
'best': MagicMock(url=self.stream_url),
'1080p': MagicMock(url=self.stream_url),
'720p': MagicMock(url=self.stream_url),
}

selected_quality = self.stream_capture.select_quality_based_on_speed()
self.assertEqual(selected_quality, self.stream_url)

@patch('src.stream_capture.streamlink.streams')
@patch('src.stream_capture.StreamCapture.check_internet_speed')
def test_select_quality_based_on_speed_no_quality(
self, mock_check_internet_speed, mock_streams,
):
mock_check_internet_speed.return_value = (2.0, 1.0)
mock_streams.return_value = {
'720p': MagicMock(url=self.stream_url),
'480p': MagicMock(url=self.stream_url),
}

selected_quality = self.stream_capture.select_quality_based_on_speed()
self.assertEqual(selected_quality, self.stream_url)

@patch('src.stream_capture.cv2.VideoCapture')
@patch('src.stream_capture.streamlink.streams')
@patch('src.stream_capture.StreamCapture.check_internet_speed')
def test_capture_youtube_frames(
self, mock_check_internet_speed, mock_streams, mock_video_capture,
):
mock_check_internet_speed.return_value = (15.0, 5.0)
mock_streams.return_value = {
'best': MagicMock(url=self.stream_url),
'1080p': MagicMock(url=self.stream_url),
'720p': MagicMock(url=self.stream_url),
}
mock_cap_instance = MagicMock()
mock_video_capture.return_value = mock_cap_instance
self.stream_capture.cap = mock_cap_instance

mock_cap_instance.read.side_effect = [
(True, MagicMock()), (True, MagicMock()),
]

frame_generator = self.stream_capture.capture_youtube_frames()

try:
frame, timestamp = next(frame_generator)
self.assertIsNotNone(frame)
self.assertIsInstance(timestamp, float)
except StopIteration:
self.fail('Generator stopped unexpectedly')

mock_now_later = datetime.datetime.now(
) + datetime.timedelta(seconds=self.capture_interval + 1)
with patch(
'src.stream_capture.datetime.datetime.now',
return_value=mock_now_later,
):
frame, timestamp = next(frame_generator)
self.assertIsNotNone(frame)
self.assertIsInstance(timestamp, float)

self.stream_capture.release_resources()

@patch('src.stream_capture.StreamCapture.capture_frames')
def test_execute_capture(self, mock_capture_frames):
expected_timestamp = time.time() # 使用当前时间戳
mock_capture_frames.return_value = iter(
[(MagicMock(), expected_timestamp)],
)
frames_generator = self.stream_capture.execute_capture()

frame, timestamp = next(frames_generator)
self.assertIsNotNone(frame)
self.assertAlmostEqual(
timestamp, expected_timestamp, delta=40,
) # 允许更大的差异

@patch('src.stream_capture.StreamCapture.capture_youtube_frames')
def test_execute_capture_youtube(self, mock_capture_youtube_frames):
self.stream_capture.stream_url = (
'https://www.youtube.com/watch?v=mf-1VZ6ewlE'
)
mock_capture_youtube_frames.return_value = iter(
[(MagicMock(), 1234567890.0)],
)
frames_generator = self.stream_capture.execute_capture()

frame, timestamp = next(frames_generator)
self.assertIsNotNone(frame)
self.assertAlmostEqual(timestamp, 1234567890.0, delta=40)


if __name__ == '__main__':
unittest.main()

0 comments on commit 2ea619e

Please sign in to comment.