Skip to content

Commit

Permalink
Boost test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
yihong1120 committed Aug 7, 2024
1 parent 3fe9ed8 commit fe0459d
Show file tree
Hide file tree
Showing 7 changed files with 136 additions and 24 deletions.
9 changes: 8 additions & 1 deletion src/live_stream_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,10 @@ def run_detection(self):
print(datas)


if __name__ == '__main__':
def main():
"""
Main function to run the live stream detection.
"""
parser = argparse.ArgumentParser(
description='Perform live stream detection and tracking using YOLOv8.',
)
Expand All @@ -140,6 +143,10 @@ def run_detection(self):
# Release resources after detection is complete
detector.release_resources()


if __name__ == '__main__':
main()

"""example
python live_stream_tracker.py --url https://cctv6.kctmc.nat.gov.tw/ea05668e/
"""
9 changes: 7 additions & 2 deletions src/stream_viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,16 @@ def release_resources(self):
self.cap.release()
cv2.destroyAllWindows()


if __name__ == '__main__':
def main():
"""
Main function to run the StreamViewer.
"""
# Replace 'vide0_url' with your stream URL.
video_url = (
'https://cctv4.kctmc.nat.gov.tw/50204bfc/'
)
viewer = StreamViewer(video_url)
viewer.display_stream()

if __name__ == '__main__':
main()
44 changes: 39 additions & 5 deletions tests/live_stream_tracker_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import numpy as np

from src.live_stream_tracker import LiveStreamDetector
from src.live_stream_tracker import main


class TestLiveStreamDetector(unittest.TestCase):
Expand Down Expand Up @@ -50,8 +51,10 @@ def test_initialisation(
@patch('src.live_stream_tracker.YOLO')
@patch('src.live_stream_tracker.cv2.VideoCapture')
@patch('src.live_stream_tracker.datetime')
@patch('src.live_stream_tracker.cv2.waitKey', return_value=0xFF & ord('q'))
def test_generate_detections(
self,
mock_wait_key: MagicMock,
mock_datetime: MagicMock,
mock_video_capture: MagicMock,
mock_yolo: MagicMock,
Expand All @@ -63,11 +66,9 @@ def test_generate_detections(
mock_video_capture.return_value = mock_cap_instance
mock_cap_instance.isOpened.side_effect = [True, True, False]
mock_cap_instance.read.side_effect = [
(
True, np.zeros(
(480, 640, 3), dtype=np.uint8,
),
), (True, np.zeros((480, 640, 3), dtype=np.uint8)), (False, None),
(True, np.zeros((480, 640, 3), dtype=np.uint8)),
(True, np.zeros((480, 640, 3), dtype=np.uint8)),
(False, None),
]

mock_yolo_instance = MagicMock()
Expand Down Expand Up @@ -98,6 +99,16 @@ def test_generate_detections(
self.assertIsInstance(frame, np.ndarray)
self.assertIsInstance(timestamp, float)

try:
ids, datas, frame, timestamp = next(frame_generator)
self.assertEqual(ids, [])
self.assertEqual(datas, [])
self.assertIsInstance(frame, np.ndarray)
self.assertIsInstance(timestamp, float)
except StopIteration:
# Allow StopIteration to pass without failing the test
pass

@patch('src.live_stream_tracker.cv2.VideoCapture')
@patch('src.live_stream_tracker.cv2.destroyAllWindows')
def test_release_resources(
Expand Down Expand Up @@ -156,6 +167,29 @@ def test_run_detection(self, mock_generate_detections: MagicMock) -> None:
),
)

@patch('argparse.ArgumentParser.parse_args')
@patch('src.live_stream_tracker.LiveStreamDetector')
def test_main(self, mock_live_stream_detector: MagicMock, mock_parse_args: MagicMock) -> None:
"""
Test case for the main function.
"""
mock_args = MagicMock()
mock_args.url = 'test_url'
mock_args.model = 'test_model'
mock_parse_args.return_value = mock_args

mock_detector_instance = MagicMock()
mock_live_stream_detector.return_value = mock_detector_instance

main()

mock_parse_args.assert_called_once()
mock_live_stream_detector.assert_called_once_with(
'test_url', 'test_model',
)
mock_detector_instance.run_detection.assert_called_once()
mock_detector_instance.release_resources.assert_called_once()


if __name__ == '__main__':
unittest.main()
16 changes: 13 additions & 3 deletions tests/messenger_notifier_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@

import unittest
from typing import Any
from unittest.mock import patch

from unittest.mock import patch, MagicMock
import numpy as np

from src.messenger_notifier import MessengerNotifier
from src.messenger_notifier import MessengerNotifier, main


class TestMessengerNotifier(unittest.TestCase):
Expand Down Expand Up @@ -92,6 +91,17 @@ def test_missing_page_access_token(self) -> None:
):
MessengerNotifier(page_access_token=None)

@patch('src.messenger_notifier.MessengerNotifier.send_notification', return_value=200)
@patch('src.messenger_notifier.os.getenv', return_value='test_page_access_token')
def test_main(self, mock_getenv: MagicMock, mock_send_notification: MagicMock) -> None:
"""
Test the main function.
"""
with patch('builtins.print') as mock_print:
main()
mock_send_notification.assert_called_once()
mock_print.assert_called_with('Response code: 200')


if __name__ == '__main__':
unittest.main()
23 changes: 18 additions & 5 deletions tests/stream_viewer_test.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from __future__ import annotations

import unittest
from unittest.mock import MagicMock
from unittest.mock import patch
from unittest.mock import MagicMock, patch

from src.stream_viewer import StreamViewer
from src.stream_viewer import StreamViewer, main


class TestStreamViewer(unittest.TestCase):
Expand Down Expand Up @@ -39,7 +38,7 @@ def test_display_stream(
mock_imshow: MagicMock,
mock_wait_key: MagicMock,
mock_video_capture: MagicMock,
mock_destroyAllWindows: MagicMock,
mock_destroy_all_windows: MagicMock,
) -> None:
"""
Test the display_stream method for streaming video.
Expand Down Expand Up @@ -70,7 +69,7 @@ def test_display_stream(
self.assertGreaterEqual(mock_cap_instance.read.call_count, 2)

# Check if destroyAllWindows was called
mock_destroyAllWindows.assert_called_once()
mock_destroy_all_windows.assert_called_once()

@patch('src.stream_viewer.cv2.VideoCapture')
@patch('src.stream_viewer.cv2.destroyAllWindows')
Expand All @@ -96,6 +95,20 @@ def test_release_resources(
# Check if destroyAllWindows was called
mock_destroy_all_windows.assert_called_once()

@patch('src.stream_viewer.StreamViewer.display_stream')
@patch('src.stream_viewer.StreamViewer.__init__', return_value=None)
def test_main(self, mock_init: MagicMock, mock_display_stream: MagicMock) -> None:
"""
Test the main function.
"""
main()

# Check if StreamViewer was initialised with the correct URL
mock_init.assert_called_once_with('https://cctv4.kctmc.nat.gov.tw/50204bfc/')

# Check if display_stream was called
mock_display_stream.assert_called_once()


if __name__ == '__main__':
unittest.main()
26 changes: 22 additions & 4 deletions tests/telegram_notifier_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,10 @@

import unittest
from io import BytesIO
from unittest.mock import AsyncMock
from unittest.mock import patch

from unittest.mock import AsyncMock, patch
import numpy as np

from src.telegram_notifier import TelegramNotifier
from src.telegram_notifier import TelegramNotifier, main


class TestTelegramNotifier(unittest.IsolatedAsyncioTestCase):
Expand Down Expand Up @@ -74,6 +72,26 @@ async def test_send_notification_with_image(
self.assertEqual(kwargs['caption'], message)
self.assertIsInstance(kwargs['photo'], BytesIO)

@patch('src.telegram_notifier.TelegramNotifier.send_notification', new_callable=AsyncMock)
@patch('src.telegram_notifier.os.getenv')
async def test_main(self, mock_getenv: AsyncMock, mock_send_notification: AsyncMock) -> None:
"""
Test the main function.
"""
mock_getenv.side_effect = lambda key: 'test_bot_token' if key == 'TELEGRAM_BOT_TOKEN' else None
mock_send_notification.return_value = 'Message sent'

with patch('builtins.print') as mock_print:
await main()
mock_send_notification.assert_called_once()
args, kwargs = mock_send_notification.call_args
self.assertEqual(args[0], 'your_chat_id_here')
self.assertEqual(args[1], 'Hello, Telegram!')
if len(args) > 2:
self.assertIsInstance(args[2], np.ndarray)
self.assertEqual(args[2].shape, (100, 100, 3))
mock_print.assert_called_once_with('Message sent')


if __name__ == '__main__':
unittest.main()
33 changes: 29 additions & 4 deletions tests/wechat_notifier_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,11 @@

import unittest
from io import BytesIO
from unittest.mock import MagicMock
from unittest.mock import patch

from unittest.mock import MagicMock, patch
import numpy as np
from PIL import Image

from src.wechat_notifier import WeChatNotifier
from src.wechat_notifier import WeChatNotifier, main


class TestWeChatNotifier(unittest.TestCase):
Expand Down Expand Up @@ -130,6 +128,33 @@ def test_upload_media(self, mock_post: MagicMock) -> None:
self.assertIsInstance(kwargs['files']['media'][1], BytesIO)
self.assertEqual(kwargs['files']['media'][2], 'image/png')

@patch('src.wechat_notifier.WeChatNotifier.send_notification', return_value={'errcode': 0, 'errmsg': 'ok'})
@patch('src.wechat_notifier.WeChatNotifier.get_access_token', return_value='test_access_token')
@patch('src.wechat_notifier.os.getenv')
def test_main(self, mock_getenv: MagicMock, mock_get_access_token: MagicMock, mock_send_notification: MagicMock) -> None:
"""
Test the main function.
"""
mock_getenv.side_effect = lambda key: {
'WECHAT_CORP_ID': 'test_corp_id',
'WECHAT_CORP_SECRET': 'test_corp_secret',
'WECHAT_AGENT_ID': '1000002',
}.get(key, '')

with patch('builtins.print') as mock_print:
main()
mock_send_notification.assert_called_once()
args, kwargs = mock_send_notification.call_args
self.assertEqual(args[0], 'your_user_id_here')
self.assertEqual(args[1], 'Hello, WeChat!')
if len(args) > 2:
self.assertIsInstance(args[2], np.ndarray)
self.assertEqual(args[2].shape, (100, 100, 3))
mock_print.assert_called()
print_args, print_kwargs = mock_print.call_args
self.assertIn('errcode', print_args[0])
self.assertIn('errmsg', print_args[0])


if __name__ == '__main__':
unittest.main()

0 comments on commit fe0459d

Please sign in to comment.