Skip to content

Commit

Permalink
Debug and boost test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
yihong1120 committed Aug 6, 2024
1 parent b2d626b commit ca286f0
Show file tree
Hide file tree
Showing 9 changed files with 570 additions and 75 deletions.
46 changes: 39 additions & 7 deletions src/danger_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
from shapely.geometry import Point
from shapely.geometry import Polygon
from sklearn.cluster import HDBSCAN
# from hdbscan import HDBSCAN


class DangerDetector:
Expand All @@ -20,6 +19,36 @@ def __init__(self):
# Initialise the HDBSCAN clusterer
self.clusterer = HDBSCAN(min_samples=3, min_cluster_size=2)

def normalise_bbox(self, bbox):
"""
Normalises the bounding box coordinates.
Args:
bbox (list[float]): The bounding box coordinates.
Returns:
list[float]: Normalised coordinates.
"""
left_x = min(bbox[0], bbox[2])
right_x = max(bbox[0], bbox[2])
top_y = min(bbox[1], bbox[3])
bottom_y = max(bbox[1], bbox[3])
if len(bbox) > 4:
return [left_x, top_y, right_x, bottom_y, bbox[4], bbox[5]]
return [left_x, top_y, right_x, bottom_y]

def normalise_data(self, datas):
"""
Normalises a list of bounding box data.
Args:
datas (list[list[float]]): List of bounding box data.
Returns:
list[list[float]]: Normalised data.
"""
return [self.normalise_bbox(data[:4] + data[4:]) for data in datas]

def detect_polygon_from_cones(
self,
datas: list[list[float]],
Expand Down Expand Up @@ -131,6 +160,9 @@ def detect_danger(
"""
warnings = set() # Initialise the list to store warning messages

# Normalise data
datas = self.normalise_data(datas)

# Check if people are entering the controlled area
polygons = self.detect_polygon_from_cones(datas)
people_count = self.calculate_people_in_controlled_area(
Expand Down Expand Up @@ -184,10 +216,7 @@ def detect_danger(
return warnings, polygons

@staticmethod
def is_driver(
person_bbox: list[float],
vehicle_bbox: list[float],
) -> bool:
def is_driver(person_bbox: list[float], vehicle_bbox: list[float]) -> bool:
"""
Check if a person is a driver based on position near a vehicle.
Expand Down Expand Up @@ -322,8 +351,7 @@ def is_dangerously_close(
)


# Example usage
if __name__ == '__main__':
def main():
detector = DangerDetector()

data: list[list[float]] = [
Expand Down Expand Up @@ -351,3 +379,7 @@ def is_dangerously_close(
warnings, polygons = detector.detect_danger(data)
for warning in warnings:
print(warning)


if __name__ == '__main__':
main()
17 changes: 11 additions & 6 deletions src/drawing_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,8 +186,12 @@ def save_frame(self, frame_bytes: bytearray, output_filename: str) -> None:
gc.collect()


if __name__ == '__main__':
# Example usage (replace with actual usage)
def main():
"""
Main function to process and save the frame with detections.
"""
drawer_saver = DrawingManager()

# Load frame and detection data (example)
frame = np.zeros((480, 640, 3), dtype=np.uint8)

Expand All @@ -210,15 +214,16 @@ def save_frame(self, frame_bytes: bytearray, output_filename: str) -> None:
points = [(100, 100), (250, 250), (450, 450), (500, 200), (150, 400)]
polygon = Polygon(points).convex_hull

# Initialise DrawingManager class
drawer_saver = DrawingManager()

# Draw detections on frame (including safety cones)
frame_with_detections = drawer_saver.draw_detections_on_frame(
frame, polygon, datas,
frame, [polygon], datas,
)

# Save the frame with detections
output_filename = 'frame_001'
frame_bytes = cv2.imencode('.png', frame_with_detections)[1].tobytes()
drawer_saver.save_frame(frame_bytes, output_filename)


if __name__ == '__main__':
main()
86 changes: 40 additions & 46 deletions src/live_stream_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,11 +242,9 @@ def remove_overlapping_labels(self, datas):
list: A list of detection data with overlapping labels removed.
"""
hardhat_indices = [
i
for i, d in enumerate(
i for i, d in enumerate(
datas,
)
if d[5] == 0
) if d[5] == 0
] # Indices of Hardhat detections
# Indices of NO-Hardhat detections
no_hardhat_indices = [i for i, d in enumerate(datas) if d[5] == 2]
Expand All @@ -258,29 +256,26 @@ def remove_overlapping_labels(self, datas):
to_remove = set()
for hardhat_index in hardhat_indices:
for no_hardhat_index in no_hardhat_indices:
if (
self.overlap_percentage(
datas[hardhat_index][:4],
datas[no_hardhat_index][:4],
)
> 0.8
):
overlap = self.overlap_percentage(
datas[hardhat_index][:4], datas[no_hardhat_index][:4],
)
if overlap > 0.8:
to_remove.add(no_hardhat_index)

for safety_vest_index in safety_vest_indices:
for no_safety_vest_index in no_safety_vest_indices:
if (
self.overlap_percentage(
datas[safety_vest_index][:4],
datas[no_safety_vest_index][:4],
)
> 0.8
):
overlap = self.overlap_percentage(
datas[safety_vest_index][:4],
datas[no_safety_vest_index][:4],
)
if overlap > 0.8:
to_remove.add(no_safety_vest_index)

for index in sorted(to_remove, reverse=True):
datas.pop(index)

return datas

gc.collect()
return datas

Expand Down Expand Up @@ -346,10 +341,13 @@ def remove_completely_contained_labels(self, datas):
)
if d[5] == 0
] # Indices of Hardhat detections

# Indices of NO-Hardhat detections
no_hardhat_indices = [i for i, d in enumerate(datas) if d[5] == 2]

# Indices of Safety Vest detections
safety_vest_indices = [i for i, d in enumerate(datas) if d[5] == 7]

# Indices of NO-Safety Vest detections
no_safety_vest_indices = [i for i, d in enumerate(datas) if d[5] == 4]

Expand Down Expand Up @@ -414,34 +412,26 @@ def run_detection(self, stream_url: str) -> None:
stream_url (str): The URL of the live stream.
"""
cap = cv2.VideoCapture(stream_url)
cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)

while True:
ret, frame = cap.read()
if not ret:
print('Failed to read frame from the stream. Retrying...')
time.sleep(2)
continue

try:
datas, _ = self.generate_detections(frame)
print(f"datas: {datas}")
except Exception as e:
print(f"Detection error: {e}")

# Clear variables to free up memory
del frame, ret, datas
gc.collect()

# Break loop if 'q' key is pressed
if cv2.waitKey(1) & 0xFF == ord('q'):
break

cap.release()
cv2.destroyAllWindows()


if __name__ == '__main__':
if not cap.isOpened():
raise ValueError('Failed to open stream.')

try:
while True:
ret, frame = cap.read()
if not ret:
print('Failed to read frame from the stream. Retrying...')
continue

# Perform detection
cv2.imshow('Frame', frame)
if cv2.waitKey(1) & 0xFF == ord('q'):
break
finally:
cap.release()
cv2.destroyAllWindows()


def main():
parser = argparse.ArgumentParser(
description='Perform live stream detection and tracking using YOLOv8.',
)
Expand Down Expand Up @@ -482,3 +472,7 @@ def run_detection(self, stream_url: str) -> None:
run_local=args.run_local,
)
detector.run_detection(args.url)


if __name__ == '__main__':
main()
13 changes: 9 additions & 4 deletions src/monitor_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,18 @@ def get_logger(self):
return self.logger


# This block is executed when the script is run directly, not when imported
if __name__ == '__main__':
# Example usage of the LoggerConfig class:

def main():
"""
Main function to initialise logger and log a message.
"""
# Initialise the logger configuration
logger_config = LoggerConfig()
logger = logger_config.get_logger()

# Log a message indicating that the logging setup is complete
logger.info('Logging setup complete.')


# This block is executed when the script is run directly, not when imported
if __name__ == '__main__':
main()
Loading

0 comments on commit ca286f0

Please sign in to comment.