diff --git a/src/stream_viewer.py b/src/stream_viewer.py index 74d0eba..20f9b23 100644 --- a/src/stream_viewer.py +++ b/src/stream_viewer.py @@ -1,11 +1,7 @@ from __future__ import annotations import cv2 -from typing import TypedDict -class StreamConfig(TypedDict): - stream_url: str - window_name: str class StreamViewer: """ @@ -13,18 +9,22 @@ class StreamViewer: Attributes: stream_url (str): The URL of the video stream. - window_name (str): The name of the window where the stream will be displayed. + window_name (str): The name of the window where the stream will + be displayed. """ - def __init__(self, config: StreamConfig): + def __init__(self, stream_url: str, window_name: str = 'Stream Viewer'): """ - Initialises the StreamViewer instance with a stream URL and a window name. + Initialises the StreamViewer instance with a stream URL + and a window name. Args: - config (StreamConfig): The configuration for the stream viewer. + stream_url (str): The URL of the video stream. + window_name (str): The name of the window where the stream + will be displayed. """ - self.stream_url = config['stream_url'] - self.window_name = config['window_name'] + self.stream_url = stream_url + self.window_name = window_name self.cap = cv2.VideoCapture(self.stream_url) def display_stream(self): @@ -32,7 +32,8 @@ def display_stream(self): Displays the video stream in a window. Continuously captures frames from the video stream and displays them. - The loop breaks when 'q' is pressed or if the stream cannot be retrieved. + The loop breaks when 'q' is pressed or if the stream cannot be + retrieved. """ while True: # Capture the next frame from the stream. @@ -62,11 +63,9 @@ def release_resources(self): if __name__ == '__main__': - # Define the stream configuration - config: StreamConfig = { - 'stream_url': 'https://cctv4.kctmc.nat.gov.tw/50204bfc/', - 'window_name': 'Stream Viewer' - } - - viewer = StreamViewer(config) + # Replace 'vide0_url' with your stream URL. + video_url = ( + 'https://cctv4.kctmc.nat.gov.tw/50204bfc/' + ) + viewer = StreamViewer(video_url) viewer.display_stream() diff --git a/tests/stream_viewer_test.py b/tests/stream_viewer_test.py new file mode 100644 index 0000000..5f8ad8f --- /dev/null +++ b/tests/stream_viewer_test.py @@ -0,0 +1,79 @@ +from __future__ import annotations + +import unittest +from unittest.mock import MagicMock +from unittest.mock import patch + +from src.stream_viewer import StreamViewer + + +class TestStreamViewer(unittest.TestCase): + + @patch('src.stream_viewer.cv2.VideoCapture') + def test_initialization(self, mock_video_capture): + # Initialize StreamViewer with a test URL + stream_url = 'https://example.com/stream' + viewer = StreamViewer(stream_url) + + # Check if the URL is set correctly + self.assertEqual(viewer.stream_url, stream_url) + + # Check if the window name is set correctly + self.assertEqual(viewer.window_name, 'Stream Viewer') + + # Check if VideoCapture was called with the correct URL + mock_video_capture.assert_called_once_with(stream_url) + + @patch('src.stream_viewer.cv2.VideoCapture') + @patch('src.stream_viewer.cv2.imshow') + @patch('src.stream_viewer.cv2.waitKey') + def test_display_stream( + self, mock_wait_key, mock_imshow, mock_video_capture, + ): + # Mock VideoCapture instance + mock_cap_instance = MagicMock() + mock_video_capture.return_value = mock_cap_instance + + # Simulate read() returning True with a dummy frame + mock_cap_instance.read.side_effect = [ + (True, MagicMock()), (True, MagicMock()), (False, None), + ] + + # Simulate waitKey() returning 'q' to break the loop + mock_wait_key.side_effect = [ord('a'), ord('b'), ord('q')] + + # Initialize StreamViewer and call display_stream + viewer = StreamViewer('https://example.com/stream') + viewer.display_stream() + + # Check if imshow was called correctly + self.assertEqual(mock_imshow.call_count, 2) + + # Check if waitKey was called at least twice + self.assertGreaterEqual(mock_wait_key.call_count, 2) + + # Check if read was called at least twice + self.assertGreaterEqual(mock_cap_instance.read.call_count, 2) + + @patch('src.stream_viewer.cv2.VideoCapture') + @patch('src.stream_viewer.cv2.destroyAllWindows') + def test_release_resources( + self, mock_destroy_all_windows, mock_video_capture, + ): + # Mock VideoCapture instance + mock_cap_instance = MagicMock() + mock_video_capture.return_value = mock_cap_instance + + # Initialize StreamViewer and call release_resources + viewer = StreamViewer('https://example.com/stream') + viewer.release_resources() + + # Check if release was called on VideoCapture instance + mock_cap_instance.release.assert_called_once() + + # Check if destroyAllWindows was called + mock_destroy_all_windows.assert_called_once() + + +if __name__ == '__main__': + unittest.main()