From ae5ca2dd55eea9203f7e531b495f4007786f11fd Mon Sep 17 00:00:00 2001 From: yihong1120 Date: Mon, 8 Jul 2024 19:57:28 +0800 Subject: [PATCH] Add test file --- src/danger_detector.py | 128 ++++++++++++++++++++-------------- tests/danger_detector_test.py | 69 ++++++++++++++++++ 2 files changed, 146 insertions(+), 51 deletions(-) create mode 100644 tests/danger_detector_test.py diff --git a/src/danger_detector.py b/src/danger_detector.py index 254a5e9..ee61c73 100644 --- a/src/danger_detector.py +++ b/src/danger_detector.py @@ -1,5 +1,7 @@ from __future__ import annotations -from typing import TypedDict, List, Set + +from typing import TypedDict + class BoundingBox(TypedDict): x1: float @@ -7,10 +9,12 @@ class BoundingBox(TypedDict): x2: float y2: float + class DetectionData(TypedDict): bbox: BoundingBox confidence: float - class_label: float + class_label: int + class DangerDetector: """ @@ -23,7 +27,7 @@ def __init__(self): """ pass - def detect_danger(self, datas: List[DetectionData]) -> Set[str]: + def detect_danger(self, datas: list[list[float]]) -> set[str]: """ Detects potential safety violations in a construction site. @@ -32,34 +36,40 @@ def detect_danger(self, datas: List[DetectionData]) -> Set[str]: 2. Workers dangerously close to machinery or vehicles. Args: - datas (List[DetectionData]): A list of detections which includes + datas (List[List[float]]): A list of detections which includes bounding box coordinates, confidence score, and class label. Returns: - Set[str]: A set of warning messages for safety violations. + List[str]: A list of warning messages for safety violations. """ warnings = set() # Initialise the list to store warning messages # Classify detected objects into different categories - persons = [d for d in datas if d['class_label'] == 5.0] # Persons - hardhat_violations = [d for d in datas if d['class_label'] == 2.0] # No hardhat - safety_vest_violations = [d for d in datas if d['class_label'] == 4.0] # No safety vest - machinery_vehicles = [d for d in datas if d['class_label'] in [8.0, 9.0]] # Machinery and vehicles + persons = [d for d in datas if d[5] == 5] # Persons + hardhat_violations = [d for d in datas if d[5] == 2] # No hardhat + safety_vest_violations = [ + d for d in datas if d[5] == 4 + ] # No safety vest + machinery_vehicles = [ + d for d in datas if d[5] in [8, 9] + ] # Machinery and vehicles # Filter out persons who are likely drivers if machinery_vehicles: non_drivers = [ p for p in persons if not any( - self.is_driver(p['bbox'], mv['bbox']) for mv in machinery_vehicles + self.is_driver(p[:4], mv[:4]) for mv in machinery_vehicles ) ] persons = non_drivers # Check for hardhat and safety vest violations for violation in hardhat_violations + safety_vest_violations: - label = 'NO-Hardhat' if violation['class_label'] == 2.0 else 'NO-Safety Vest' + label = 'NO-Hardhat' if violation[5] == 2 else 'NO-Safety Vest' if not any( - self.overlap_percentage(violation['bbox'], p['bbox']) > 0.5 for p in persons + self.overlap_percentage( + violation[:4], p[:4], + ) > 0.5 for p in persons ): warning_msg = ( '警告: 有人無配戴安全帽!' if label == 'NO-Hardhat' @@ -70,38 +80,41 @@ def detect_danger(self, datas: List[DetectionData]) -> Set[str]: # Check if anyone is dangerously close to machinery or vehicles for person in persons: for mv in machinery_vehicles: - label = '機具' if mv['class_label'] == 8.0 else '車輛' - if self.is_dangerously_close(person['bbox'], mv['bbox'], label): + label = '機具' if mv[5] == 8 else '車輛' + if self.is_dangerously_close(person[:4], mv[:4], label): warnings.add(f"警告: 有人過於靠近{label}!") break return warnings @staticmethod - def is_driver(person_bbox: BoundingBox, vehicle_bbox: BoundingBox) -> bool: + def is_driver( + person_bbox: list[float], + vehicle_bbox: list[float], + ) -> bool: """ Check if a person is a driver based on position near a vehicle. Args: - person_bbox (BoundingBox): Bounding box of person. - vehicle_bbox (BoundingBox): Bounding box of vehicle. + person_bbox (list[float]): Bounding box of person. + vehicle_bbox (list[float]): Bounding box of vehicle. Returns: bool: True if the person is likely the driver, False otherwise. """ # Extract coordinates and dimensions of person and vehicle boxes - person_bottom_y = person_bbox['y2'] - person_top_y = person_bbox['y1'] - person_left_x = person_bbox['x1'] - person_right_x = person_bbox['x2'] - person_width = person_bbox['x2'] - person_bbox['x1'] - person_height = person_bbox['y2'] - person_bbox['y1'] - - vehicle_top_y = vehicle_bbox['y1'] - vehicle_bottom_y = vehicle_bbox['y2'] - vehicle_left_x = vehicle_bbox['x1'] - vehicle_right_x = vehicle_bbox['x2'] - vehicle_height = vehicle_bbox['y2'] - vehicle_bbox['y1'] + person_bottom_y = person_bbox[3] + person_top_y = person_bbox[1] + person_left_x = person_bbox[0] + person_right_x = person_bbox[2] + person_width = person_bbox[2] - person_bbox[0] + person_height = person_bbox[3] - person_bbox[1] + + vehicle_top_y = vehicle_bbox[1] + vehicle_bottom_y = vehicle_bbox[3] + vehicle_left_x = vehicle_bbox[0] + vehicle_right_x = vehicle_bbox[2] + vehicle_height = vehicle_bbox[3] - vehicle_bbox[1] # 1. Check vertical bottom position: person's bottom should be above # the vehicle's bottom by at least half the person's height @@ -130,53 +143,62 @@ def is_driver(person_bbox: BoundingBox, vehicle_bbox: BoundingBox) -> bool: return True @staticmethod - def overlap_percentage(bbox1: BoundingBox, bbox2: BoundingBox) -> float: + def overlap_percentage( + bbox1: list[float], + bbox2: list[float], + ) -> float: """ Calculate the overlap percentage between two bounding boxes. Args: - bbox1 (BoundingBox): The first bounding box. - bbox2 (BoundingBox): The second bounding box. + bbox1 (list[float]): The first bounding box. + bbox2 (list[float]): The second bounding box. Returns: float: The overlap percentage. """ # Calculate the coordinates of the intersection rectangle - x1 = max(bbox1['x1'], bbox2['x1']) - y1 = max(bbox1['y1'], bbox2['y1']) - x2 = min(bbox1['x2'], bbox2['x2']) - y2 = min(bbox1['y2'], bbox2['y2']) + x1 = max(bbox1[0], bbox2[0]) + y1 = max(bbox1[1], bbox2[1]) + x2 = min(bbox1[2], bbox2[2]) + y2 = min(bbox1[3], bbox2[3]) # Calculate the area of the intersection rectangle overlap_area = max(0, x2 - x1) * max(0, y2 - y1) # Calculate the area of both bounding boxes - area1 = (bbox1['x2'] - bbox1['x1']) * (bbox1['y2'] - bbox1['y1']) - area2 = (bbox2['x2'] - bbox2['x1']) * (bbox2['y2'] - bbox2['y1']) + area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1]) + area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1]) # Calculate the overlap percentage return overlap_area / float(area1 + area2 - overlap_area) @staticmethod - def is_dangerously_close(person_bbox: BoundingBox, vehicle_bbox: BoundingBox, label: str) -> bool: + def is_dangerously_close( + person_bbox: list[float], + vehicle_bbox: list[float], + label: str, + ) -> bool: """ Determine if a person is dangerously close to machinery or vehicles. Args: - person_bbox (BoundingBox): Bounding box of person. - vehicle_bbox (BoundingBox): Machine/vehicle box. + person_bbox (list[float]): Bounding box of person. + vehicle_bbox (list[float]): Machine/vehicle box. label (str): Type of the second object ('machinery' or 'vehicle'). Returns: bool: True if the person is dangerously close, False otherwise. """ # Calculate dimensions of the person bounding box - person_width = person_bbox['x2'] - person_bbox['x1'] - person_height = person_bbox['y2'] - person_bbox['y1'] + person_width = person_bbox[2] - person_bbox[0] + person_height = person_bbox[3] - person_bbox[1] person_area = person_width * person_height # Calculate the area of the vehicle bounding box - vehicle_area = (vehicle_bbox['x2'] - vehicle_bbox['x1']) * (vehicle_bbox['y2'] - vehicle_bbox['y1']) + vehicle_area = (vehicle_bbox[2] - vehicle_bbox[0]) * ( + vehicle_bbox[3] - vehicle_bbox[1] + ) acceptable_ratio = 0.1 if label == '車輛' else 0.05 # Check if person area ratio is acceptable compared to vehicle area @@ -189,12 +211,16 @@ def is_dangerously_close(person_bbox: BoundingBox, vehicle_bbox: BoundingBox, la # Calculate min horizontal/vertical distance between person and vehicle horizontal_distance = min( - abs(person_bbox['x2'] - vehicle_bbox['x1']), - abs(person_bbox['x1'] - vehicle_bbox['x2']), + abs( + person_bbox[2] - vehicle_bbox[0], + ), + abs(person_bbox[0] - vehicle_bbox[2]), ) vertical_distance = min( - abs(person_bbox['y2'] - vehicle_bbox['y1']), - abs(person_bbox['y1'] - vehicle_bbox['y2']), + abs( + person_bbox[3] - vehicle_bbox[1], + ), + abs(person_bbox[1] - vehicle_bbox[3]), ) # Determine if the person is dangerously close @@ -208,9 +234,9 @@ def is_dangerously_close(person_bbox: BoundingBox, vehicle_bbox: BoundingBox, la if __name__ == '__main__': detector = DangerDetector() data = [ - {'bbox': {'x1': 706.87, 'y1': 445.07, 'x2': 976.32, 'y2': 1073.6}, 'confidence': 0.91, 'class_label': 5.0}, - {'bbox': {'x1': 0.45513, 'y1': 471.77, 'x2': 662.03, 'y2': 1071.4}, 'confidence': 0.75853, 'class_label': 12.0}, - {'bbox': {'x1': 1042.7, 'y1': 638.5, 'x2': 1077.5, 'y2': 731.98}, 'confidence': 0.56060, 'class_label': 18.0}, + [706.87, 445.07, 976.32, 1073.6, 3, 0.91], + [0.45513, 471.77, 662.03, 1071.4, 12, 0.75853], + [1042.7, 638.5, 1077.5, 731.98, 18, 0.56060], ] warnings = detector.detect_danger(data) for warning in warnings: diff --git a/tests/danger_detector_test.py b/tests/danger_detector_test.py new file mode 100644 index 0000000..7da8403 --- /dev/null +++ b/tests/danger_detector_test.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +import unittest + +from examples.user_management.models import User +from src.danger_detector import DangerDetector + + +class TestDangerDetector(unittest.TestCase): + def setUp(self): + self.detector = DangerDetector() + + def test_set_password(self): + user = User(username='test_user') + user.set_password('password') + self.assertTrue(user.check_password('password')) + self.assertFalse(user.check_password('wrong_password')) + + def test_is_driver(self): + person_bbox = [100, 200, 150, 250] + vehicle_bbox = [50, 100, 200, 300] + self.assertTrue(self.detector.is_driver(person_bbox, vehicle_bbox)) + + person_bbox = [100, 200, 200, 400] + self.assertFalse(self.detector.is_driver(person_bbox, vehicle_bbox)) + + def test_overlap_percentage(self): + bbox1 = [100, 100, 200, 200] + bbox2 = [150, 150, 250, 250] + self.assertAlmostEqual( + self.detector.overlap_percentage( + bbox1, bbox2, + ), 0.142857, places=6, + ) + + bbox1 = [100, 100, 200, 200] + bbox2 = [300, 300, 400, 400] + self.assertEqual(self.detector.overlap_percentage(bbox1, bbox2), 0.0) + + def test_is_dangerously_close(self): + person_bbox = [100, 100, 120, 120] + vehicle_bbox = [100, 100, 200, 200] + self.assertTrue( + self.detector.is_dangerously_close( + person_bbox, vehicle_bbox, '車輛', + ), + ) + + person_bbox = [0, 0, 10, 10] + vehicle_bbox = [100, 100, 200, 200] + self.assertFalse( + self.detector.is_dangerously_close( + person_bbox, vehicle_bbox, '車輛', + ), + ) + + def test_detect_danger(self): + data = [ + [706.87, 445.07, 976.32, 1073.6, 3, 5.0], + [0.45513, 471.77, 662.03, 1071.4, 12, 5.0], + [1042.7, 638.5, 1077.5, 731.98, 18, 2.0], + ] + warnings = self.detector.detect_danger(data) + self.assertIn('警告: 有人無配戴安全帽!', warnings) + self.assertNotIn('警告: 有人無穿著安全背心!', warnings) + + +if __name__ == '__main__': + unittest.main()