Skip to content

Commit

Permalink
Add test file
Browse files Browse the repository at this point in the history
  • Loading branch information
yihong1120 committed Jul 8, 2024
1 parent 3c9cc99 commit ae5ca2d
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 51 deletions.
128 changes: 77 additions & 51 deletions src/danger_detector.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,20 @@
from __future__ import annotations
from typing import TypedDict, List, Set

from typing import TypedDict


class BoundingBox(TypedDict):
x1: float
y1: float
x2: float
y2: float


class DetectionData(TypedDict):
bbox: BoundingBox
confidence: float
class_label: float
class_label: int


class DangerDetector:
"""
Expand All @@ -23,7 +27,7 @@ def __init__(self):
"""
pass

def detect_danger(self, datas: List[DetectionData]) -> Set[str]:
def detect_danger(self, datas: list[list[float]]) -> set[str]:
"""
Detects potential safety violations in a construction site.
Expand All @@ -32,34 +36,40 @@ def detect_danger(self, datas: List[DetectionData]) -> Set[str]:
2. Workers dangerously close to machinery or vehicles.
Args:
datas (List[DetectionData]): A list of detections which includes
datas (List[List[float]]): A list of detections which includes
bounding box coordinates, confidence score, and class label.
Returns:
Set[str]: A set of warning messages for safety violations.
List[str]: A list of warning messages for safety violations.
"""
warnings = set() # Initialise the list to store warning messages

# Classify detected objects into different categories
persons = [d for d in datas if d['class_label'] == 5.0] # Persons
hardhat_violations = [d for d in datas if d['class_label'] == 2.0] # No hardhat
safety_vest_violations = [d for d in datas if d['class_label'] == 4.0] # No safety vest
machinery_vehicles = [d for d in datas if d['class_label'] in [8.0, 9.0]] # Machinery and vehicles
persons = [d for d in datas if d[5] == 5] # Persons
hardhat_violations = [d for d in datas if d[5] == 2] # No hardhat
safety_vest_violations = [
d for d in datas if d[5] == 4
] # No safety vest
machinery_vehicles = [
d for d in datas if d[5] in [8, 9]
] # Machinery and vehicles

# Filter out persons who are likely drivers
if machinery_vehicles:
non_drivers = [
p for p in persons if not any(
self.is_driver(p['bbox'], mv['bbox']) for mv in machinery_vehicles
self.is_driver(p[:4], mv[:4]) for mv in machinery_vehicles
)
]
persons = non_drivers

# Check for hardhat and safety vest violations
for violation in hardhat_violations + safety_vest_violations:
label = 'NO-Hardhat' if violation['class_label'] == 2.0 else 'NO-Safety Vest'
label = 'NO-Hardhat' if violation[5] == 2 else 'NO-Safety Vest'
if not any(
self.overlap_percentage(violation['bbox'], p['bbox']) > 0.5 for p in persons
self.overlap_percentage(
violation[:4], p[:4],
) > 0.5 for p in persons
):
warning_msg = (
'警告: 有人無配戴安全帽!' if label == 'NO-Hardhat'
Expand All @@ -70,38 +80,41 @@ def detect_danger(self, datas: List[DetectionData]) -> Set[str]:
# Check if anyone is dangerously close to machinery or vehicles
for person in persons:
for mv in machinery_vehicles:
label = '機具' if mv['class_label'] == 8.0 else '車輛'
if self.is_dangerously_close(person['bbox'], mv['bbox'], label):
label = '機具' if mv[5] == 8 else '車輛'
if self.is_dangerously_close(person[:4], mv[:4], label):
warnings.add(f"警告: 有人過於靠近{label}!")
break

return warnings

@staticmethod
def is_driver(person_bbox: BoundingBox, vehicle_bbox: BoundingBox) -> bool:
def is_driver(
person_bbox: list[float],
vehicle_bbox: list[float],
) -> bool:
"""
Check if a person is a driver based on position near a vehicle.
Args:
person_bbox (BoundingBox): Bounding box of person.
vehicle_bbox (BoundingBox): Bounding box of vehicle.
person_bbox (list[float]): Bounding box of person.
vehicle_bbox (list[float]): Bounding box of vehicle.
Returns:
bool: True if the person is likely the driver, False otherwise.
"""
# Extract coordinates and dimensions of person and vehicle boxes
person_bottom_y = person_bbox['y2']
person_top_y = person_bbox['y1']
person_left_x = person_bbox['x1']
person_right_x = person_bbox['x2']
person_width = person_bbox['x2'] - person_bbox['x1']
person_height = person_bbox['y2'] - person_bbox['y1']

vehicle_top_y = vehicle_bbox['y1']
vehicle_bottom_y = vehicle_bbox['y2']
vehicle_left_x = vehicle_bbox['x1']
vehicle_right_x = vehicle_bbox['x2']
vehicle_height = vehicle_bbox['y2'] - vehicle_bbox['y1']
person_bottom_y = person_bbox[3]
person_top_y = person_bbox[1]
person_left_x = person_bbox[0]
person_right_x = person_bbox[2]
person_width = person_bbox[2] - person_bbox[0]
person_height = person_bbox[3] - person_bbox[1]

vehicle_top_y = vehicle_bbox[1]
vehicle_bottom_y = vehicle_bbox[3]
vehicle_left_x = vehicle_bbox[0]
vehicle_right_x = vehicle_bbox[2]
vehicle_height = vehicle_bbox[3] - vehicle_bbox[1]

# 1. Check vertical bottom position: person's bottom should be above
# the vehicle's bottom by at least half the person's height
Expand Down Expand Up @@ -130,53 +143,62 @@ def is_driver(person_bbox: BoundingBox, vehicle_bbox: BoundingBox) -> bool:
return True

@staticmethod
def overlap_percentage(bbox1: BoundingBox, bbox2: BoundingBox) -> float:
def overlap_percentage(
bbox1: list[float],
bbox2: list[float],
) -> float:
"""
Calculate the overlap percentage between two bounding boxes.
Args:
bbox1 (BoundingBox): The first bounding box.
bbox2 (BoundingBox): The second bounding box.
bbox1 (list[float]): The first bounding box.
bbox2 (list[float]): The second bounding box.
Returns:
float: The overlap percentage.
"""
# Calculate the coordinates of the intersection rectangle
x1 = max(bbox1['x1'], bbox2['x1'])
y1 = max(bbox1['y1'], bbox2['y1'])
x2 = min(bbox1['x2'], bbox2['x2'])
y2 = min(bbox1['y2'], bbox2['y2'])
x1 = max(bbox1[0], bbox2[0])
y1 = max(bbox1[1], bbox2[1])
x2 = min(bbox1[2], bbox2[2])
y2 = min(bbox1[3], bbox2[3])

# Calculate the area of the intersection rectangle
overlap_area = max(0, x2 - x1) * max(0, y2 - y1)

# Calculate the area of both bounding boxes
area1 = (bbox1['x2'] - bbox1['x1']) * (bbox1['y2'] - bbox1['y1'])
area2 = (bbox2['x2'] - bbox2['x1']) * (bbox2['y2'] - bbox2['y1'])
area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
area2 = (bbox2[2] - bbox2[0]) * (bbox2[3] - bbox2[1])

# Calculate the overlap percentage
return overlap_area / float(area1 + area2 - overlap_area)

@staticmethod
def is_dangerously_close(person_bbox: BoundingBox, vehicle_bbox: BoundingBox, label: str) -> bool:
def is_dangerously_close(
person_bbox: list[float],
vehicle_bbox: list[float],
label: str,
) -> bool:
"""
Determine if a person is dangerously close to machinery or vehicles.
Args:
person_bbox (BoundingBox): Bounding box of person.
vehicle_bbox (BoundingBox): Machine/vehicle box.
person_bbox (list[float]): Bounding box of person.
vehicle_bbox (list[float]): Machine/vehicle box.
label (str): Type of the second object ('machinery' or 'vehicle').
Returns:
bool: True if the person is dangerously close, False otherwise.
"""
# Calculate dimensions of the person bounding box
person_width = person_bbox['x2'] - person_bbox['x1']
person_height = person_bbox['y2'] - person_bbox['y1']
person_width = person_bbox[2] - person_bbox[0]
person_height = person_bbox[3] - person_bbox[1]
person_area = person_width * person_height

# Calculate the area of the vehicle bounding box
vehicle_area = (vehicle_bbox['x2'] - vehicle_bbox['x1']) * (vehicle_bbox['y2'] - vehicle_bbox['y1'])
vehicle_area = (vehicle_bbox[2] - vehicle_bbox[0]) * (
vehicle_bbox[3] - vehicle_bbox[1]
)
acceptable_ratio = 0.1 if label == '車輛' else 0.05

# Check if person area ratio is acceptable compared to vehicle area
Expand All @@ -189,12 +211,16 @@ def is_dangerously_close(person_bbox: BoundingBox, vehicle_bbox: BoundingBox, la

# Calculate min horizontal/vertical distance between person and vehicle
horizontal_distance = min(
abs(person_bbox['x2'] - vehicle_bbox['x1']),
abs(person_bbox['x1'] - vehicle_bbox['x2']),
abs(
person_bbox[2] - vehicle_bbox[0],
),
abs(person_bbox[0] - vehicle_bbox[2]),
)
vertical_distance = min(
abs(person_bbox['y2'] - vehicle_bbox['y1']),
abs(person_bbox['y1'] - vehicle_bbox['y2']),
abs(
person_bbox[3] - vehicle_bbox[1],
),
abs(person_bbox[1] - vehicle_bbox[3]),
)

# Determine if the person is dangerously close
Expand All @@ -208,9 +234,9 @@ def is_dangerously_close(person_bbox: BoundingBox, vehicle_bbox: BoundingBox, la
if __name__ == '__main__':
detector = DangerDetector()
data = [
{'bbox': {'x1': 706.87, 'y1': 445.07, 'x2': 976.32, 'y2': 1073.6}, 'confidence': 0.91, 'class_label': 5.0},
{'bbox': {'x1': 0.45513, 'y1': 471.77, 'x2': 662.03, 'y2': 1071.4}, 'confidence': 0.75853, 'class_label': 12.0},
{'bbox': {'x1': 1042.7, 'y1': 638.5, 'x2': 1077.5, 'y2': 731.98}, 'confidence': 0.56060, 'class_label': 18.0},
[706.87, 445.07, 976.32, 1073.6, 3, 0.91],
[0.45513, 471.77, 662.03, 1071.4, 12, 0.75853],
[1042.7, 638.5, 1077.5, 731.98, 18, 0.56060],
]
warnings = detector.detect_danger(data)
for warning in warnings:
Expand Down
69 changes: 69 additions & 0 deletions tests/danger_detector_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from __future__ import annotations

import unittest

from examples.user_management.models import User
from src.danger_detector import DangerDetector


class TestDangerDetector(unittest.TestCase):
def setUp(self):
self.detector = DangerDetector()

def test_set_password(self):
user = User(username='test_user')
user.set_password('password')
self.assertTrue(user.check_password('password'))
self.assertFalse(user.check_password('wrong_password'))

def test_is_driver(self):
person_bbox = [100, 200, 150, 250]
vehicle_bbox = [50, 100, 200, 300]
self.assertTrue(self.detector.is_driver(person_bbox, vehicle_bbox))

person_bbox = [100, 200, 200, 400]
self.assertFalse(self.detector.is_driver(person_bbox, vehicle_bbox))

def test_overlap_percentage(self):
bbox1 = [100, 100, 200, 200]
bbox2 = [150, 150, 250, 250]
self.assertAlmostEqual(
self.detector.overlap_percentage(
bbox1, bbox2,
), 0.142857, places=6,
)

bbox1 = [100, 100, 200, 200]
bbox2 = [300, 300, 400, 400]
self.assertEqual(self.detector.overlap_percentage(bbox1, bbox2), 0.0)

def test_is_dangerously_close(self):
person_bbox = [100, 100, 120, 120]
vehicle_bbox = [100, 100, 200, 200]
self.assertTrue(
self.detector.is_dangerously_close(
person_bbox, vehicle_bbox, '車輛',
),
)

person_bbox = [0, 0, 10, 10]
vehicle_bbox = [100, 100, 200, 200]
self.assertFalse(
self.detector.is_dangerously_close(
person_bbox, vehicle_bbox, '車輛',
),
)

def test_detect_danger(self):
data = [
[706.87, 445.07, 976.32, 1073.6, 3, 5.0],
[0.45513, 471.77, 662.03, 1071.4, 12, 5.0],
[1042.7, 638.5, 1077.5, 731.98, 18, 2.0],
]
warnings = self.detector.detect_danger(data)
self.assertIn('警告: 有人無配戴安全帽!', warnings)
self.assertNotIn('警告: 有人無穿著安全背心!', warnings)


if __name__ == '__main__':
unittest.main()

0 comments on commit ae5ca2d

Please sign in to comment.