From 8092518d9ca381ae9be43a386ed6c9bda6c755af Mon Sep 17 00:00:00 2001 From: YiHungWONG <40423264+yihong1120@users.noreply.github.com> Date: Fri, 9 Aug 2024 02:56:00 +0000 Subject: [PATCH] Reduce complexity --- examples/YOLOv8_server_api/detection.py | 217 +++++++++++++--------- examples/streaming_web/static/js/label.js | 72 ++++--- 2 files changed, 176 insertions(+), 113 deletions(-) diff --git a/examples/YOLOv8_server_api/detection.py b/examples/YOLOv8_server_api/detection.py index 8347142..756b24a 100644 --- a/examples/YOLOv8_server_api/detection.py +++ b/examples/YOLOv8_server_api/detection.py @@ -4,9 +4,7 @@ import cv2 import numpy as np -from flask import Blueprint -from flask import jsonify -from flask import request +from flask import Blueprint, jsonify, request from flask_jwt_extended import jwt_required from flask_limiter import Limiter from flask_limiter.util import get_remote_address @@ -27,12 +25,42 @@ def detect(): model_key = request.args.get('model', default='yolov8n', type=str) model = model_loader.get_model(model_key) # 从 DetectionModelManager 获取模型 - # Convert string data to numpy array + img = convert_to_image(data) + result = get_prediction_result(img, model) + + datas = compile_detection_data(result) + datas = process_labels(datas) + + return jsonify(datas) + + +def convert_to_image(data): + """ + Convert string data to an image. + + Args: + data (bytes): Image data in bytes. + + Returns: + numpy.ndarray: Decoded image. + """ npimg = np.frombuffer(data, np.uint8) - # Convert numpy array to image img = cv2.imdecode(npimg, cv2.IMREAD_COLOR) + return img + - result = get_sliced_prediction( +def get_prediction_result(img, model): + """ + Get the prediction result from the model. + + Args: + img (numpy.ndarray): Input image. + model: Detection model. + + Returns: + Result: Prediction result. + """ + return get_sliced_prediction( img, model, slice_height=376, @@ -41,24 +69,40 @@ def detect(): overlap_width_ratio=0.3, ) - # Compile detection data in YOLOv8 format + +def compile_detection_data(result): + """ + Compile detection data in YOLOv8 format. + + Args: + result: Prediction result. + + Returns: + list: Compiled detection data. + """ datas = [] for object_prediction in result.object_prediction_list: label = int(object_prediction.category.id) x1, y1, x2, y2 = (int(x) for x in object_prediction.bbox.to_voc_bbox()) confidence = float(object_prediction.score.value) datas.append([x1, y1, x2, y2, confidence, label]) + return datas - # Remove overlapping labels for Hardhat and Safety Vest categories - datas = remove_overlapping_labels(datas) - # Remove completely contained labels for Hardhat and Safety Vest categories - datas = remove_completely_contained_labels(datas) +def process_labels(datas): + """ + Process detection labels to remove overlaps and contained labels. - # Remove overlapping labels for Hardhat and Safety Vest categories - datas = remove_overlapping_labels(datas) + Args: + datas (list): Detection data. - return jsonify(datas) + Returns: + list: Processed detection data. + """ + datas = remove_overlapping_labels(datas) + datas = remove_completely_contained_labels(datas) + datas = remove_overlapping_labels(datas) + return datas def remove_overlapping_labels(datas): @@ -71,42 +115,11 @@ def remove_overlapping_labels(datas): Returns: list: A list of detection data with overlapping labels removed. """ - hardhat_indices = [ - i - for i, d in enumerate( - datas, - ) - if d[5] == 0 - ] # Indices of Hardhat detections - # Indices of NO-Hardhat detections - no_hardhat_indices = [i for i, d in enumerate(datas) if d[5] == 2] - # Indices of Safety Vest detections - safety_vest_indices = [i for i, d in enumerate(datas) if d[5] == 7] - # Indices of NO-Safety Vest detections - no_safety_vest_indices = [i for i, d in enumerate(datas) if d[5] == 4] + hardhat_indices, no_hardhat_indices, safety_vest_indices, no_safety_vest_indices = get_category_indices(datas) to_remove = set() - for hardhat_index in hardhat_indices: - for no_hardhat_index in no_hardhat_indices: - if ( - overlap_percentage( - datas[hardhat_index][:4], - datas[no_hardhat_index][:4], - ) - > 0.8 - ): - to_remove.add(no_hardhat_index) - - for safety_vest_index in safety_vest_indices: - for no_safety_vest_index in no_safety_vest_indices: - if ( - overlap_percentage( - datas[safety_vest_index][:4], - datas[no_safety_vest_index][:4], - ) - > 0.8 - ): - to_remove.add(no_safety_vest_index) + to_remove.update(find_overlaps(hardhat_indices, no_hardhat_indices, datas, 0.8)) + to_remove.update(find_overlaps(safety_vest_indices, no_safety_vest_indices, datas, 0.8)) for index in sorted(to_remove, reverse=True): datas.pop(index) @@ -115,6 +128,44 @@ def remove_overlapping_labels(datas): return datas +def get_category_indices(datas): + """ + Get indices of different categories in the detection data. + + Args: + datas (list): A list of detection data in YOLOv8 format. + + Returns: + tuple: Indices of Hardhat, NO-Hardhat, Safety Vest, and NO-Safety Vest detections. + """ + hardhat_indices = [i for i, d in enumerate(datas) if d[5] == 0] + no_hardhat_indices = [i for i, d in enumerate(datas) if d[5] == 2] + safety_vest_indices = [i for i, d in enumerate(datas) if d[5] == 7] + no_safety_vest_indices = [i for i, d in enumerate(datas) if d[5] == 4] + return hardhat_indices, no_hardhat_indices, safety_vest_indices, no_safety_vest_indices + + +def find_overlaps(indices1, indices2, datas, threshold): + """ + Find overlapping labels between two sets of indices. + + Args: + indices1 (list): First set of indices. + indices2 (list): Second set of indices. + datas (list): Detection data. + threshold (float): Overlap threshold. + + Returns: + set: Indices of overlapping labels to remove. + """ + to_remove = set() + for index1 in indices1: + for index2 in indices2: + if overlap_percentage(datas[index1][:4], datas[index2][:4]) > threshold: + to_remove.add(index2) + return to_remove + + def overlap_percentage(bbox1, bbox2): """ Calculates the percentage of overlap between two bounding boxes. @@ -135,11 +186,8 @@ def overlap_percentage(bbox1, bbox2): bbox1_area = (bbox1[2] - bbox1[0] + 1) * (bbox1[3] - bbox1[1] + 1) bbox2_area = (bbox2[2] - bbox2[0] + 1) * (bbox2[3] - bbox2[1] + 1) - overlap_percentage = intersection_area / float( - bbox1_area + bbox2_area - intersection_area, - ) + overlap_percentage = intersection_area / float(bbox1_area + bbox2_area - intersection_area) gc.collect() - return overlap_percentage @@ -172,50 +220,35 @@ def remove_completely_contained_labels(datas): Returns: list: Detection data with fully contained labels removed. """ - hardhat_indices = [ - i - for i, d in enumerate( - datas, - ) - if d[5] == 0 - ] # Indices of Hardhat detections - # Indices of NO-Hardhat detections - no_hardhat_indices = [i for i, d in enumerate(datas) if d[5] == 2] - # Indices of Safety Vest detections - safety_vest_indices = [i for i, d in enumerate(datas) if d[5] == 7] - # Indices of NO-Safety Vest detections - no_safety_vest_indices = [i for i, d in enumerate(datas) if d[5] == 4] + hardhat_indices, no_hardhat_indices, safety_vest_indices, no_safety_vest_indices = get_category_indices(datas) to_remove = set() - # Check hardhats - for hardhat_index in hardhat_indices: - for no_hardhat_index in no_hardhat_indices: - if is_contained( - datas[no_hardhat_index][:4], - datas[hardhat_index][:4], - ): - to_remove.add(no_hardhat_index) - elif is_contained( - datas[hardhat_index][:4], - datas[no_hardhat_index][:4], - ): - to_remove.add(hardhat_index) - - # Check safety vests - for safety_vest_index in safety_vest_indices: - for no_safety_vest_index in no_safety_vest_indices: - if is_contained( - datas[no_safety_vest_index][:4], - datas[safety_vest_index][:4], - ): - to_remove.add(no_safety_vest_index) - elif is_contained( - datas[safety_vest_index][:4], - datas[no_safety_vest_index][:4], - ): - to_remove.add(safety_vest_index) + to_remove.update(find_contained_labels(hardhat_indices, no_hardhat_indices, datas)) + to_remove.update(find_contained_labels(safety_vest_indices, no_safety_vest_indices, datas)) for index in sorted(to_remove, reverse=True): datas.pop(index) return datas + + +def find_contained_labels(indices1, indices2, datas): + """ + Find completely contained labels between two sets of indices. + + Args: + indices1 (list): First set of indices. + indices2 (list): Second set of indices. + datas (list): Detection data. + + Returns: + set: Indices of completely contained labels to remove. + """ + to_remove = set() + for index1 in indices1: + for index2 in indices2: + if is_contained(datas[index2][:4], datas[index1][:4]): + to_remove.add(index2) + elif is_contained(datas[index1][:4], datas[index2][:4]): + to_remove.add(index1) + return to_remove \ No newline at end of file diff --git a/examples/streaming_web/static/js/label.js b/examples/streaming_web/static/js/label.js index 0f5b41c..846f1a5 100644 --- a/examples/streaming_web/static/js/label.js +++ b/examples/streaming_web/static/js/label.js @@ -1,15 +1,15 @@ $(document).ready(() => { - // 自动检测当前页面协议,以决定 ws 还是 wss + // Automatically detect the current page protocol to decide between ws and wss const protocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://'; - // 创建 WebSocket 连接,并配置重连策略 + // Create WebSocket connection and configure reconnection strategy const socket = io.connect(protocol + document.domain + ':' + location.port, { transports: ['websocket'], - reconnectionAttempts: 5, // 最多重连尝试 5 次 - reconnectionDelay: 2000 // 重连间隔为 2000 毫秒 + reconnectionAttempts: 5, // Maximum of 5 reconnection attempts + reconnectionDelay: 2000 // Reconnection interval of 2000 milliseconds }); - // 获取当前页面的标签名 - const currentPageLabel = $('h1').text(); // 假设页面的

标签包含了当前的标签名称 + // Get the label of the current page + const currentPageLabel = $('h1').text(); // Assuming the

tag contains the current label name socket.on('connect', () => { console.log('WebSocket connected!'); @@ -24,20 +24,50 @@ $(document).ready(() => { }); socket.on('update', (data) => { - // 检查接收到的数据是否适用于当前页面的标签 - if (data.label === currentPageLabel) { - console.log('Received update for current label:', data.label); - const fragment = document.createDocumentFragment(); - data.images.forEach((image, index) => { - const cameraDiv = $('
').addClass('camera'); - const title = $('

').text(data.image_names[index]); - const img = $('').attr('src', `data:image/png;base64,${image}`).attr('alt', `${data.label} image`); - cameraDiv.append(title).append(img); - fragment.appendChild(cameraDiv[0]); - }); - $('.camera-grid').empty().append(fragment); - } else { - console.log('Received update for different label:', data.label); - } + handleUpdate(data, currentPageLabel); }); }); + +/** + * Handle WebSocket updates + * @param {Object} data - The received data + * @param {string} currentPageLabel - The label of the current page + */ +function handleUpdate(data, currentPageLabel) { + // Check if the received data is applicable to the current page's label + if (data.label === currentPageLabel) { + console.log('Received update for current label:', data.label); + updateCameraGrid(data); + } else { + console.log('Received update for different label:', data.label); + } +} + +/** + * Update the camera grid + * @param {Object} data - The data containing images and names + */ +function updateCameraGrid(data) { + const fragment = document.createDocumentFragment(); + data.images.forEach((image, index) => { + const cameraDiv = createCameraDiv(image, index, data.image_names, data.label); + fragment.appendChild(cameraDiv); + }); + $('.camera-grid').empty().append(fragment); +} + +/** + * Create a camera div element + * @param {string} image - The image data + * @param {number} index - The image index + * @param {Array} imageNames - The array of image names + * @param {string} label - The label name + * @returns {HTMLElement} - The div element containing the image and title + */ +function createCameraDiv(image, index, imageNames, label) { + const cameraDiv = $('
').addClass('camera'); + const title = $('

').text(imageNames[index]); + const img = $('').attr('src', `data:image/png;base64,${image}`).attr('alt', `${label} image`); + cameraDiv.append(title).append(img); + return cameraDiv[0]; +} \ No newline at end of file