From 4d3df8044c87db87a85a711d22cb32d392003a5e Mon Sep 17 00:00:00 2001 From: yihong1120 Date: Tue, 9 Jul 2024 19:31:34 +0800 Subject: [PATCH] Fix TypeError by moving tensors to CPU before converting to numpy --- src/live_stream_tracker.py | 6 +++--- tests/live_stream_tracker_test.py | 13 ++++++++----- 2 files changed, 11 insertions(+), 8 deletions(-) diff --git a/src/live_stream_tracker.py b/src/live_stream_tracker.py index 059efde..d916cd1 100644 --- a/src/live_stream_tracker.py +++ b/src/live_stream_tracker.py @@ -16,7 +16,7 @@ class LiveStreamDetector: def __init__( self, stream_url: str, - model_path: str = '../models/yolov8n.pt', + model_path: str = '../models/pt/best_yolov8n.pt', ): """ Initialise live stream detector with video URL, YOLO model path. @@ -58,12 +58,12 @@ def generate_detections(self) -> Generator[tuple, None, None]: # Convert ids and datas to lists if they are not empty ids_list = ( - ids.numpy().tolist() + ids.cpu().numpy().tolist() if ids is not None and len(ids) > 0 else [] ) datas_list = ( - datas.numpy().tolist() + datas.cpu().numpy().tolist() if datas is not None and len(datas) > 0 else [] ) diff --git a/tests/live_stream_tracker_test.py b/tests/live_stream_tracker_test.py index 46b26e9..92c4c9a 100644 --- a/tests/live_stream_tracker_test.py +++ b/tests/live_stream_tracker_test.py @@ -49,13 +49,16 @@ def test_generate_detections( mock_yolo_instance = MagicMock() mock_yolo.return_value = mock_yolo_instance mock_results = MagicMock() - mock_results[0].boxes = MagicMock() - mock_results[0].boxes.id = MagicMock() - mock_results[0].boxes.data = MagicMock() - mock_results[0].boxes.id.numpy.return_value = [1, 2, 3] - mock_results[0].boxes.data.numpy.return_value = [ + mock_boxes = MagicMock() + mock_boxes.id = MagicMock() + mock_boxes.data = MagicMock() + mock_boxes.id.cpu.return_value = mock_boxes.id + mock_boxes.data.cpu.return_value = mock_boxes.data + mock_boxes.id.numpy.return_value = [1, 2, 3] + mock_boxes.data.numpy.return_value = [ [0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8], ] + mock_results[0].boxes = mock_boxes mock_yolo_instance.track.return_value = mock_results mock_now = datetime.datetime(2023, 1, 1, 0, 0, 0)