Skip to content

Commit

Permalink
Reduce complexity
Browse files Browse the repository at this point in the history
  • Loading branch information
yihong1120 committed Aug 9, 2024
1 parent 92edb1f commit 8092518
Show file tree
Hide file tree
Showing 2 changed files with 176 additions and 113 deletions.
217 changes: 125 additions & 92 deletions examples/YOLOv8_server_api/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@

import cv2
import numpy as np
from flask import Blueprint
from flask import jsonify
from flask import request
from flask import Blueprint, jsonify, request
from flask_jwt_extended import jwt_required
from flask_limiter import Limiter
from flask_limiter.util import get_remote_address
Expand All @@ -27,12 +25,42 @@ def detect():
model_key = request.args.get('model', default='yolov8n', type=str)
model = model_loader.get_model(model_key) # 从 DetectionModelManager 获取模型

# Convert string data to numpy array
img = convert_to_image(data)
result = get_prediction_result(img, model)

datas = compile_detection_data(result)
datas = process_labels(datas)

return jsonify(datas)


def convert_to_image(data):
"""
Convert string data to an image.
Args:
data (bytes): Image data in bytes.
Returns:
numpy.ndarray: Decoded image.
"""
npimg = np.frombuffer(data, np.uint8)
# Convert numpy array to image
img = cv2.imdecode(npimg, cv2.IMREAD_COLOR)
return img


result = get_sliced_prediction(
def get_prediction_result(img, model):
"""
Get the prediction result from the model.
Args:
img (numpy.ndarray): Input image.
model: Detection model.
Returns:
Result: Prediction result.
"""
return get_sliced_prediction(
img,
model,
slice_height=376,
Expand All @@ -41,24 +69,40 @@ def detect():
overlap_width_ratio=0.3,
)

# Compile detection data in YOLOv8 format

def compile_detection_data(result):
"""
Compile detection data in YOLOv8 format.
Args:
result: Prediction result.
Returns:
list: Compiled detection data.
"""
datas = []
for object_prediction in result.object_prediction_list:
label = int(object_prediction.category.id)
x1, y1, x2, y2 = (int(x) for x in object_prediction.bbox.to_voc_bbox())
confidence = float(object_prediction.score.value)
datas.append([x1, y1, x2, y2, confidence, label])
return datas

# Remove overlapping labels for Hardhat and Safety Vest categories
datas = remove_overlapping_labels(datas)

# Remove completely contained labels for Hardhat and Safety Vest categories
datas = remove_completely_contained_labels(datas)
def process_labels(datas):
"""
Process detection labels to remove overlaps and contained labels.
# Remove overlapping labels for Hardhat and Safety Vest categories
datas = remove_overlapping_labels(datas)
Args:
datas (list): Detection data.
return jsonify(datas)
Returns:
list: Processed detection data.
"""
datas = remove_overlapping_labels(datas)
datas = remove_completely_contained_labels(datas)
datas = remove_overlapping_labels(datas)
return datas


def remove_overlapping_labels(datas):
Expand All @@ -71,42 +115,11 @@ def remove_overlapping_labels(datas):
Returns:
list: A list of detection data with overlapping labels removed.
"""
hardhat_indices = [
i
for i, d in enumerate(
datas,
)
if d[5] == 0
] # Indices of Hardhat detections
# Indices of NO-Hardhat detections
no_hardhat_indices = [i for i, d in enumerate(datas) if d[5] == 2]
# Indices of Safety Vest detections
safety_vest_indices = [i for i, d in enumerate(datas) if d[5] == 7]
# Indices of NO-Safety Vest detections
no_safety_vest_indices = [i for i, d in enumerate(datas) if d[5] == 4]
hardhat_indices, no_hardhat_indices, safety_vest_indices, no_safety_vest_indices = get_category_indices(datas)

to_remove = set()
for hardhat_index in hardhat_indices:
for no_hardhat_index in no_hardhat_indices:
if (
overlap_percentage(
datas[hardhat_index][:4],
datas[no_hardhat_index][:4],
)
> 0.8
):
to_remove.add(no_hardhat_index)

for safety_vest_index in safety_vest_indices:
for no_safety_vest_index in no_safety_vest_indices:
if (
overlap_percentage(
datas[safety_vest_index][:4],
datas[no_safety_vest_index][:4],
)
> 0.8
):
to_remove.add(no_safety_vest_index)
to_remove.update(find_overlaps(hardhat_indices, no_hardhat_indices, datas, 0.8))
to_remove.update(find_overlaps(safety_vest_indices, no_safety_vest_indices, datas, 0.8))

for index in sorted(to_remove, reverse=True):
datas.pop(index)
Expand All @@ -115,6 +128,44 @@ def remove_overlapping_labels(datas):
return datas


def get_category_indices(datas):
"""
Get indices of different categories in the detection data.
Args:
datas (list): A list of detection data in YOLOv8 format.
Returns:
tuple: Indices of Hardhat, NO-Hardhat, Safety Vest, and NO-Safety Vest detections.
"""
hardhat_indices = [i for i, d in enumerate(datas) if d[5] == 0]
no_hardhat_indices = [i for i, d in enumerate(datas) if d[5] == 2]
safety_vest_indices = [i for i, d in enumerate(datas) if d[5] == 7]
no_safety_vest_indices = [i for i, d in enumerate(datas) if d[5] == 4]
return hardhat_indices, no_hardhat_indices, safety_vest_indices, no_safety_vest_indices


def find_overlaps(indices1, indices2, datas, threshold):
"""
Find overlapping labels between two sets of indices.
Args:
indices1 (list): First set of indices.
indices2 (list): Second set of indices.
datas (list): Detection data.
threshold (float): Overlap threshold.
Returns:
set: Indices of overlapping labels to remove.
"""
to_remove = set()
for index1 in indices1:
for index2 in indices2:
if overlap_percentage(datas[index1][:4], datas[index2][:4]) > threshold:
to_remove.add(index2)
return to_remove


def overlap_percentage(bbox1, bbox2):
"""
Calculates the percentage of overlap between two bounding boxes.
Expand All @@ -135,11 +186,8 @@ def overlap_percentage(bbox1, bbox2):
bbox1_area = (bbox1[2] - bbox1[0] + 1) * (bbox1[3] - bbox1[1] + 1)
bbox2_area = (bbox2[2] - bbox2[0] + 1) * (bbox2[3] - bbox2[1] + 1)

overlap_percentage = intersection_area / float(
bbox1_area + bbox2_area - intersection_area,
)
overlap_percentage = intersection_area / float(bbox1_area + bbox2_area - intersection_area)
gc.collect()

return overlap_percentage


Expand Down Expand Up @@ -172,50 +220,35 @@ def remove_completely_contained_labels(datas):
Returns:
list: Detection data with fully contained labels removed.
"""
hardhat_indices = [
i
for i, d in enumerate(
datas,
)
if d[5] == 0
] # Indices of Hardhat detections
# Indices of NO-Hardhat detections
no_hardhat_indices = [i for i, d in enumerate(datas) if d[5] == 2]
# Indices of Safety Vest detections
safety_vest_indices = [i for i, d in enumerate(datas) if d[5] == 7]
# Indices of NO-Safety Vest detections
no_safety_vest_indices = [i for i, d in enumerate(datas) if d[5] == 4]
hardhat_indices, no_hardhat_indices, safety_vest_indices, no_safety_vest_indices = get_category_indices(datas)

to_remove = set()
# Check hardhats
for hardhat_index in hardhat_indices:
for no_hardhat_index in no_hardhat_indices:
if is_contained(
datas[no_hardhat_index][:4],
datas[hardhat_index][:4],
):
to_remove.add(no_hardhat_index)
elif is_contained(
datas[hardhat_index][:4],
datas[no_hardhat_index][:4],
):
to_remove.add(hardhat_index)

# Check safety vests
for safety_vest_index in safety_vest_indices:
for no_safety_vest_index in no_safety_vest_indices:
if is_contained(
datas[no_safety_vest_index][:4],
datas[safety_vest_index][:4],
):
to_remove.add(no_safety_vest_index)
elif is_contained(
datas[safety_vest_index][:4],
datas[no_safety_vest_index][:4],
):
to_remove.add(safety_vest_index)
to_remove.update(find_contained_labels(hardhat_indices, no_hardhat_indices, datas))
to_remove.update(find_contained_labels(safety_vest_indices, no_safety_vest_indices, datas))

for index in sorted(to_remove, reverse=True):
datas.pop(index)

return datas


def find_contained_labels(indices1, indices2, datas):
"""
Find completely contained labels between two sets of indices.
Args:
indices1 (list): First set of indices.
indices2 (list): Second set of indices.
datas (list): Detection data.
Returns:
set: Indices of completely contained labels to remove.
"""
to_remove = set()
for index1 in indices1:
for index2 in indices2:
if is_contained(datas[index2][:4], datas[index1][:4]):
to_remove.add(index2)
elif is_contained(datas[index1][:4], datas[index2][:4]):
to_remove.add(index1)
return to_remove
72 changes: 51 additions & 21 deletions examples/streaming_web/static/js/label.js
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
$(document).ready(() => {
// 自动检测当前页面协议,以决定 ws 还是 wss
// Automatically detect the current page protocol to decide between ws and wss
const protocol = window.location.protocol === 'https:' ? 'wss://' : 'ws://';
// 创建 WebSocket 连接,并配置重连策略
// Create WebSocket connection and configure reconnection strategy
const socket = io.connect(protocol + document.domain + ':' + location.port, {
transports: ['websocket'],
reconnectionAttempts: 5, // 最多重连尝试 5 次
reconnectionDelay: 2000 // 重连间隔为 2000 毫秒
reconnectionAttempts: 5, // Maximum of 5 reconnection attempts
reconnectionDelay: 2000 // Reconnection interval of 2000 milliseconds
});

// 获取当前页面的标签名
const currentPageLabel = $('h1').text(); // 假设页面的 <h1> 标签包含了当前的标签名称
// Get the label of the current page
const currentPageLabel = $('h1').text(); // Assuming the <h1> tag contains the current label name

socket.on('connect', () => {
console.log('WebSocket connected!');
Expand All @@ -24,20 +24,50 @@ $(document).ready(() => {
});

socket.on('update', (data) => {
// 检查接收到的数据是否适用于当前页面的标签
if (data.label === currentPageLabel) {
console.log('Received update for current label:', data.label);
const fragment = document.createDocumentFragment();
data.images.forEach((image, index) => {
const cameraDiv = $('<div>').addClass('camera');
const title = $('<h2>').text(data.image_names[index]);
const img = $('<img>').attr('src', `data:image/png;base64,${image}`).attr('alt', `${data.label} image`);
cameraDiv.append(title).append(img);
fragment.appendChild(cameraDiv[0]);
});
$('.camera-grid').empty().append(fragment);
} else {
console.log('Received update for different label:', data.label);
}
handleUpdate(data, currentPageLabel);
});
});

/**
* Handle WebSocket updates
* @param {Object} data - The received data
* @param {string} currentPageLabel - The label of the current page
*/
function handleUpdate(data, currentPageLabel) {
// Check if the received data is applicable to the current page's label
if (data.label === currentPageLabel) {
console.log('Received update for current label:', data.label);
updateCameraGrid(data);
} else {
console.log('Received update for different label:', data.label);
}
}

/**
* Update the camera grid
* @param {Object} data - The data containing images and names
*/
function updateCameraGrid(data) {
const fragment = document.createDocumentFragment();
data.images.forEach((image, index) => {
const cameraDiv = createCameraDiv(image, index, data.image_names, data.label);
fragment.appendChild(cameraDiv);
});
$('.camera-grid').empty().append(fragment);
}

/**
* Create a camera div element
* @param {string} image - The image data
* @param {number} index - The image index
* @param {Array} imageNames - The array of image names
* @param {string} label - The label name
* @returns {HTMLElement} - The div element containing the image and title
*/
function createCameraDiv(image, index, imageNames, label) {
const cameraDiv = $('<div>').addClass('camera');
const title = $('<h2>').text(imageNames[index]);
const img = $('<img>').attr('src', `data:image/png;base64,${image}`).attr('alt', `${label} image`);
cameraDiv.append(title).append(img);
return cameraDiv[0];
}

0 comments on commit 8092518

Please sign in to comment.