Skip to content

Commit

Permalink
Add test file
Browse files Browse the repository at this point in the history
  • Loading branch information
yihong1120 committed Jul 8, 2024
1 parent 4e42e61 commit 81c5fe4
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 18 deletions.
35 changes: 17 additions & 18 deletions src/stream_viewer.py
Original file line number Diff line number Diff line change
@@ -1,38 +1,39 @@
from __future__ import annotations

import cv2
from typing import TypedDict

class StreamConfig(TypedDict):
stream_url: str
window_name: str

class StreamViewer:
"""
A class to handle the viewing of video streams (RTSP, HTTP, etc.).
Attributes:
stream_url (str): The URL of the video stream.
window_name (str): The name of the window where the stream will be displayed.
window_name (str): The name of the window where the stream will
be displayed.
"""

def __init__(self, config: StreamConfig):
def __init__(self, stream_url: str, window_name: str = 'Stream Viewer'):
"""
Initialises the StreamViewer instance with a stream URL and a window name.
Initialises the StreamViewer instance with a stream URL
and a window name.
Args:
config (StreamConfig): The configuration for the stream viewer.
stream_url (str): The URL of the video stream.
window_name (str): The name of the window where the stream
will be displayed.
"""
self.stream_url = config['stream_url']
self.window_name = config['window_name']
self.stream_url = stream_url
self.window_name = window_name
self.cap = cv2.VideoCapture(self.stream_url)

def display_stream(self):
"""
Displays the video stream in a window.
Continuously captures frames from the video stream and displays them.
The loop breaks when 'q' is pressed or if the stream cannot be retrieved.
The loop breaks when 'q' is pressed or if the stream cannot be
retrieved.
"""
while True:
# Capture the next frame from the stream.
Expand Down Expand Up @@ -62,11 +63,9 @@ def release_resources(self):


if __name__ == '__main__':
# Define the stream configuration
config: StreamConfig = {
'stream_url': 'https://cctv4.kctmc.nat.gov.tw/50204bfc/',
'window_name': 'Stream Viewer'
}

viewer = StreamViewer(config)
# Replace 'vide0_url' with your stream URL.
video_url = (
'https://cctv4.kctmc.nat.gov.tw/50204bfc/'
)
viewer = StreamViewer(video_url)
viewer.display_stream()
79 changes: 79 additions & 0 deletions tests/stream_viewer_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
from __future__ import annotations

import unittest
from unittest.mock import MagicMock
from unittest.mock import patch

from src.stream_viewer import StreamViewer


class TestStreamViewer(unittest.TestCase):

@patch('src.stream_viewer.cv2.VideoCapture')
def test_initialization(self, mock_video_capture):
# Initialize StreamViewer with a test URL
stream_url = 'https://example.com/stream'
viewer = StreamViewer(stream_url)

# Check if the URL is set correctly
self.assertEqual(viewer.stream_url, stream_url)

# Check if the window name is set correctly
self.assertEqual(viewer.window_name, 'Stream Viewer')

# Check if VideoCapture was called with the correct URL
mock_video_capture.assert_called_once_with(stream_url)

@patch('src.stream_viewer.cv2.VideoCapture')
@patch('src.stream_viewer.cv2.imshow')
@patch('src.stream_viewer.cv2.waitKey')
def test_display_stream(
self, mock_wait_key, mock_imshow, mock_video_capture,
):
# Mock VideoCapture instance
mock_cap_instance = MagicMock()
mock_video_capture.return_value = mock_cap_instance

# Simulate read() returning True with a dummy frame
mock_cap_instance.read.side_effect = [
(True, MagicMock()), (True, MagicMock()), (False, None),
]

# Simulate waitKey() returning 'q' to break the loop
mock_wait_key.side_effect = [ord('a'), ord('b'), ord('q')]

# Initialize StreamViewer and call display_stream
viewer = StreamViewer('https://example.com/stream')
viewer.display_stream()

# Check if imshow was called correctly
self.assertEqual(mock_imshow.call_count, 2)

# Check if waitKey was called at least twice
self.assertGreaterEqual(mock_wait_key.call_count, 2)

# Check if read was called at least twice
self.assertGreaterEqual(mock_cap_instance.read.call_count, 2)

@patch('src.stream_viewer.cv2.VideoCapture')
@patch('src.stream_viewer.cv2.destroyAllWindows')
def test_release_resources(
self, mock_destroy_all_windows, mock_video_capture,
):
# Mock VideoCapture instance
mock_cap_instance = MagicMock()
mock_video_capture.return_value = mock_cap_instance

# Initialize StreamViewer and call release_resources
viewer = StreamViewer('https://example.com/stream')
viewer.release_resources()

# Check if release was called on VideoCapture instance
mock_cap_instance.release.assert_called_once()

# Check if destroyAllWindows was called
mock_destroy_all_windows.assert_called_once()


if __name__ == '__main__':
unittest.main()

0 comments on commit 81c5fe4

Please sign in to comment.