Skip to content

Commit

Permalink
Initial upload
Browse files Browse the repository at this point in the history
  • Loading branch information
yihong1120 committed Aug 4, 2024
1 parent 0d3c012 commit 6b0b9c9
Showing 1 changed file with 92 additions and 0 deletions.
92 changes: 92 additions & 0 deletions tests/detection_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from __future__ import annotations

from io import BytesIO

import cv2
import numpy as np
import pytest
from flask import Flask
from flask_jwt_extended import create_access_token
from flask_jwt_extended import JWTManager

from examples.YOLOv8_server_api.detection import detection_blueprint
from examples.YOLOv8_server_api.detection import is_contained
from examples.YOLOv8_server_api.detection import overlap_percentage
from examples.YOLOv8_server_api.detection import (
remove_completely_contained_labels,
)
from examples.YOLOv8_server_api.detection import remove_overlapping_labels

app = Flask(__name__)
# Change this in your real application
app.config['JWT_SECRET_KEY'] = 'super-secret'
jwt = JWTManager(app)
app.register_blueprint(detection_blueprint)


@pytest.fixture
def client():
with app.test_client() as client:
yield client


def test_detection_route(client):
# 创建 JWT token
with app.app_context():
access_token = create_access_token(identity='testuser')

# 加载测试图像
img = np.zeros((500, 500, 3), dtype=np.uint8)
_, buffer = cv2.imencode('.jpg', img)
img_bytes = BytesIO(buffer.tobytes())

# 测试检测端点
response = client.post(
'/detect',
headers={'Authorization': f'Bearer {access_token}'},
content_type='multipart/form-data',
data={'image': (img_bytes, 'test.jpg')},
)

assert response.status_code == 200
assert isinstance(response.json, list)


def test_remove_overlapping_labels():
datas = [
[10, 10, 50, 50, 0.9, 0], # Hardhat
[10, 10, 50, 50, 0.8, 2], # NO-Hardhat
[100, 100, 150, 150, 0.9, 7], # Safety Vest
[100, 100, 150, 150, 0.8, 4], # NO-Safety Vest
]

updated_datas = remove_overlapping_labels(datas)
assert len(updated_datas) == 2
assert all(d[5] in [0, 7] for d in updated_datas)


def test_remove_completely_contained_labels():
datas = [
[10, 10, 50, 50, 0.9, 0], # Hardhat
[15, 15, 45, 45, 0.8, 2], # NO-Hardhat contained within Hardhat
[100, 100, 150, 150, 0.9, 7], # Safety Vest
# NO-Safety Vest contained within Safety Vest
[105, 105, 145, 145, 0.8, 4],
]

updated_datas = remove_completely_contained_labels(datas)
assert len(updated_datas) == 2
assert all(d[5] in [0, 7] for d in updated_datas)


def test_overlap_percentage():
bbox1 = [10, 10, 50, 50]
bbox2 = [30, 30, 70, 70]
assert overlap_percentage(bbox1, bbox2) > 0


def test_is_contained():
outer_bbox = [10, 10, 50, 50]
inner_bbox = [20, 20, 30, 30]
assert is_contained(inner_bbox, outer_bbox)
assert not is_contained(outer_bbox, inner_bbox)

0 comments on commit 6b0b9c9

Please sign in to comment.