diff --git a/src/live_stream_tracker.py b/src/live_stream_tracker.py index ca9907e..984ff89 100644 --- a/src/live_stream_tracker.py +++ b/src/live_stream_tracker.py @@ -113,7 +113,10 @@ def run_detection(self): print(datas) -if __name__ == '__main__': +def main(): + """ + Main function to run the live stream detection. + """ parser = argparse.ArgumentParser( description='Perform live stream detection and tracking using YOLOv8.', ) @@ -140,6 +143,10 @@ def run_detection(self): # Release resources after detection is complete detector.release_resources() + +if __name__ == '__main__': + main() + """example python live_stream_tracker.py --url https://cctv6.kctmc.nat.gov.tw/ea05668e/ """ diff --git a/src/stream_viewer.py b/src/stream_viewer.py index 7e3f86a..e7a05c2 100644 --- a/src/stream_viewer.py +++ b/src/stream_viewer.py @@ -56,11 +56,16 @@ def release_resources(self): self.cap.release() cv2.destroyAllWindows() - -if __name__ == '__main__': +def main(): + """ + Main function to run the StreamViewer. + """ # Replace 'vide0_url' with your stream URL. video_url = ( 'https://cctv4.kctmc.nat.gov.tw/50204bfc/' ) viewer = StreamViewer(video_url) viewer.display_stream() + +if __name__ == '__main__': + main() diff --git a/tests/live_stream_tracker_test.py b/tests/live_stream_tracker_test.py index d287992..3788592 100644 --- a/tests/live_stream_tracker_test.py +++ b/tests/live_stream_tracker_test.py @@ -8,6 +8,7 @@ import numpy as np from src.live_stream_tracker import LiveStreamDetector +from src.live_stream_tracker import main class TestLiveStreamDetector(unittest.TestCase): @@ -50,8 +51,10 @@ def test_initialisation( @patch('src.live_stream_tracker.YOLO') @patch('src.live_stream_tracker.cv2.VideoCapture') @patch('src.live_stream_tracker.datetime') + @patch('src.live_stream_tracker.cv2.waitKey', return_value=0xFF & ord('q')) def test_generate_detections( self, + mock_wait_key: MagicMock, mock_datetime: MagicMock, mock_video_capture: MagicMock, mock_yolo: MagicMock, @@ -63,11 +66,9 @@ def test_generate_detections( mock_video_capture.return_value = mock_cap_instance mock_cap_instance.isOpened.side_effect = [True, True, False] mock_cap_instance.read.side_effect = [ - ( - True, np.zeros( - (480, 640, 3), dtype=np.uint8, - ), - ), (True, np.zeros((480, 640, 3), dtype=np.uint8)), (False, None), + (True, np.zeros((480, 640, 3), dtype=np.uint8)), + (True, np.zeros((480, 640, 3), dtype=np.uint8)), + (False, None), ] mock_yolo_instance = MagicMock() @@ -98,6 +99,16 @@ def test_generate_detections( self.assertIsInstance(frame, np.ndarray) self.assertIsInstance(timestamp, float) + try: + ids, datas, frame, timestamp = next(frame_generator) + self.assertEqual(ids, []) + self.assertEqual(datas, []) + self.assertIsInstance(frame, np.ndarray) + self.assertIsInstance(timestamp, float) + except StopIteration: + # Allow StopIteration to pass without failing the test + pass + @patch('src.live_stream_tracker.cv2.VideoCapture') @patch('src.live_stream_tracker.cv2.destroyAllWindows') def test_release_resources( @@ -156,6 +167,29 @@ def test_run_detection(self, mock_generate_detections: MagicMock) -> None: ), ) + @patch('argparse.ArgumentParser.parse_args') + @patch('src.live_stream_tracker.LiveStreamDetector') + def test_main(self, mock_live_stream_detector: MagicMock, mock_parse_args: MagicMock) -> None: + """ + Test case for the main function. + """ + mock_args = MagicMock() + mock_args.url = 'test_url' + mock_args.model = 'test_model' + mock_parse_args.return_value = mock_args + + mock_detector_instance = MagicMock() + mock_live_stream_detector.return_value = mock_detector_instance + + main() + + mock_parse_args.assert_called_once() + mock_live_stream_detector.assert_called_once_with( + 'test_url', 'test_model', + ) + mock_detector_instance.run_detection.assert_called_once() + mock_detector_instance.release_resources.assert_called_once() + if __name__ == '__main__': unittest.main() diff --git a/tests/messenger_notifier_test.py b/tests/messenger_notifier_test.py index e3b36f1..e145a04 100644 --- a/tests/messenger_notifier_test.py +++ b/tests/messenger_notifier_test.py @@ -2,11 +2,10 @@ import unittest from typing import Any -from unittest.mock import patch - +from unittest.mock import patch, MagicMock import numpy as np -from src.messenger_notifier import MessengerNotifier +from src.messenger_notifier import MessengerNotifier, main class TestMessengerNotifier(unittest.TestCase): @@ -92,6 +91,17 @@ def test_missing_page_access_token(self) -> None: ): MessengerNotifier(page_access_token=None) + @patch('src.messenger_notifier.MessengerNotifier.send_notification', return_value=200) + @patch('src.messenger_notifier.os.getenv', return_value='test_page_access_token') + def test_main(self, mock_getenv: MagicMock, mock_send_notification: MagicMock) -> None: + """ + Test the main function. + """ + with patch('builtins.print') as mock_print: + main() + mock_send_notification.assert_called_once() + mock_print.assert_called_with('Response code: 200') + if __name__ == '__main__': unittest.main() diff --git a/tests/stream_viewer_test.py b/tests/stream_viewer_test.py index 7cc60b4..908c752 100644 --- a/tests/stream_viewer_test.py +++ b/tests/stream_viewer_test.py @@ -1,10 +1,9 @@ from __future__ import annotations import unittest -from unittest.mock import MagicMock -from unittest.mock import patch +from unittest.mock import MagicMock, patch -from src.stream_viewer import StreamViewer +from src.stream_viewer import StreamViewer, main class TestStreamViewer(unittest.TestCase): @@ -39,7 +38,7 @@ def test_display_stream( mock_imshow: MagicMock, mock_wait_key: MagicMock, mock_video_capture: MagicMock, - mock_destroyAllWindows: MagicMock, + mock_destroy_all_windows: MagicMock, ) -> None: """ Test the display_stream method for streaming video. @@ -70,7 +69,7 @@ def test_display_stream( self.assertGreaterEqual(mock_cap_instance.read.call_count, 2) # Check if destroyAllWindows was called - mock_destroyAllWindows.assert_called_once() + mock_destroy_all_windows.assert_called_once() @patch('src.stream_viewer.cv2.VideoCapture') @patch('src.stream_viewer.cv2.destroyAllWindows') @@ -96,6 +95,20 @@ def test_release_resources( # Check if destroyAllWindows was called mock_destroy_all_windows.assert_called_once() + @patch('src.stream_viewer.StreamViewer.display_stream') + @patch('src.stream_viewer.StreamViewer.__init__', return_value=None) + def test_main(self, mock_init: MagicMock, mock_display_stream: MagicMock) -> None: + """ + Test the main function. + """ + main() + + # Check if StreamViewer was initialised with the correct URL + mock_init.assert_called_once_with('https://cctv4.kctmc.nat.gov.tw/50204bfc/') + + # Check if display_stream was called + mock_display_stream.assert_called_once() + if __name__ == '__main__': unittest.main() diff --git a/tests/telegram_notifier_test.py b/tests/telegram_notifier_test.py index 1b216af..f2ffbd4 100644 --- a/tests/telegram_notifier_test.py +++ b/tests/telegram_notifier_test.py @@ -2,12 +2,10 @@ import unittest from io import BytesIO -from unittest.mock import AsyncMock -from unittest.mock import patch - +from unittest.mock import AsyncMock, patch import numpy as np -from src.telegram_notifier import TelegramNotifier +from src.telegram_notifier import TelegramNotifier, main class TestTelegramNotifier(unittest.IsolatedAsyncioTestCase): @@ -74,6 +72,26 @@ async def test_send_notification_with_image( self.assertEqual(kwargs['caption'], message) self.assertIsInstance(kwargs['photo'], BytesIO) + @patch('src.telegram_notifier.TelegramNotifier.send_notification', new_callable=AsyncMock) + @patch('src.telegram_notifier.os.getenv') + async def test_main(self, mock_getenv: AsyncMock, mock_send_notification: AsyncMock) -> None: + """ + Test the main function. + """ + mock_getenv.side_effect = lambda key: 'test_bot_token' if key == 'TELEGRAM_BOT_TOKEN' else None + mock_send_notification.return_value = 'Message sent' + + with patch('builtins.print') as mock_print: + await main() + mock_send_notification.assert_called_once() + args, kwargs = mock_send_notification.call_args + self.assertEqual(args[0], 'your_chat_id_here') + self.assertEqual(args[1], 'Hello, Telegram!') + if len(args) > 2: + self.assertIsInstance(args[2], np.ndarray) + self.assertEqual(args[2].shape, (100, 100, 3)) + mock_print.assert_called_once_with('Message sent') + if __name__ == '__main__': unittest.main() diff --git a/tests/wechat_notifier_test.py b/tests/wechat_notifier_test.py index 3ab0a94..abc105e 100644 --- a/tests/wechat_notifier_test.py +++ b/tests/wechat_notifier_test.py @@ -2,13 +2,11 @@ import unittest from io import BytesIO -from unittest.mock import MagicMock -from unittest.mock import patch - +from unittest.mock import MagicMock, patch import numpy as np from PIL import Image -from src.wechat_notifier import WeChatNotifier +from src.wechat_notifier import WeChatNotifier, main class TestWeChatNotifier(unittest.TestCase): @@ -130,6 +128,33 @@ def test_upload_media(self, mock_post: MagicMock) -> None: self.assertIsInstance(kwargs['files']['media'][1], BytesIO) self.assertEqual(kwargs['files']['media'][2], 'image/png') + @patch('src.wechat_notifier.WeChatNotifier.send_notification', return_value={'errcode': 0, 'errmsg': 'ok'}) + @patch('src.wechat_notifier.WeChatNotifier.get_access_token', return_value='test_access_token') + @patch('src.wechat_notifier.os.getenv') + def test_main(self, mock_getenv: MagicMock, mock_get_access_token: MagicMock, mock_send_notification: MagicMock) -> None: + """ + Test the main function. + """ + mock_getenv.side_effect = lambda key: { + 'WECHAT_CORP_ID': 'test_corp_id', + 'WECHAT_CORP_SECRET': 'test_corp_secret', + 'WECHAT_AGENT_ID': '1000002', + }.get(key, '') + + with patch('builtins.print') as mock_print: + main() + mock_send_notification.assert_called_once() + args, kwargs = mock_send_notification.call_args + self.assertEqual(args[0], 'your_user_id_here') + self.assertEqual(args[1], 'Hello, WeChat!') + if len(args) > 2: + self.assertIsInstance(args[2], np.ndarray) + self.assertEqual(args[2].shape, (100, 100, 3)) + mock_print.assert_called() + print_args, print_kwargs = mock_print.call_args + self.assertIn('errcode', print_args[0]) + self.assertIn('errmsg', print_args[0]) + if __name__ == '__main__': unittest.main()