-
-
Notifications
You must be signed in to change notification settings - Fork 24
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
0d3c012
commit 6b0b9c9
Showing
1 changed file
with
92 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
from __future__ import annotations | ||
|
||
from io import BytesIO | ||
|
||
import cv2 | ||
import numpy as np | ||
import pytest | ||
from flask import Flask | ||
from flask_jwt_extended import create_access_token | ||
from flask_jwt_extended import JWTManager | ||
|
||
from examples.YOLOv8_server_api.detection import detection_blueprint | ||
from examples.YOLOv8_server_api.detection import is_contained | ||
from examples.YOLOv8_server_api.detection import overlap_percentage | ||
from examples.YOLOv8_server_api.detection import ( | ||
remove_completely_contained_labels, | ||
) | ||
from examples.YOLOv8_server_api.detection import remove_overlapping_labels | ||
|
||
app = Flask(__name__) | ||
# Change this in your real application | ||
app.config['JWT_SECRET_KEY'] = 'super-secret' | ||
jwt = JWTManager(app) | ||
app.register_blueprint(detection_blueprint) | ||
|
||
|
||
@pytest.fixture | ||
def client(): | ||
with app.test_client() as client: | ||
yield client | ||
|
||
|
||
def test_detection_route(client): | ||
# 创建 JWT token | ||
with app.app_context(): | ||
access_token = create_access_token(identity='testuser') | ||
|
||
# 加载测试图像 | ||
img = np.zeros((500, 500, 3), dtype=np.uint8) | ||
_, buffer = cv2.imencode('.jpg', img) | ||
img_bytes = BytesIO(buffer.tobytes()) | ||
|
||
# 测试检测端点 | ||
response = client.post( | ||
'/detect', | ||
headers={'Authorization': f'Bearer {access_token}'}, | ||
content_type='multipart/form-data', | ||
data={'image': (img_bytes, 'test.jpg')}, | ||
) | ||
|
||
assert response.status_code == 200 | ||
assert isinstance(response.json, list) | ||
|
||
|
||
def test_remove_overlapping_labels(): | ||
datas = [ | ||
[10, 10, 50, 50, 0.9, 0], # Hardhat | ||
[10, 10, 50, 50, 0.8, 2], # NO-Hardhat | ||
[100, 100, 150, 150, 0.9, 7], # Safety Vest | ||
[100, 100, 150, 150, 0.8, 4], # NO-Safety Vest | ||
] | ||
|
||
updated_datas = remove_overlapping_labels(datas) | ||
assert len(updated_datas) == 2 | ||
assert all(d[5] in [0, 7] for d in updated_datas) | ||
|
||
|
||
def test_remove_completely_contained_labels(): | ||
datas = [ | ||
[10, 10, 50, 50, 0.9, 0], # Hardhat | ||
[15, 15, 45, 45, 0.8, 2], # NO-Hardhat contained within Hardhat | ||
[100, 100, 150, 150, 0.9, 7], # Safety Vest | ||
# NO-Safety Vest contained within Safety Vest | ||
[105, 105, 145, 145, 0.8, 4], | ||
] | ||
|
||
updated_datas = remove_completely_contained_labels(datas) | ||
assert len(updated_datas) == 2 | ||
assert all(d[5] in [0, 7] for d in updated_datas) | ||
|
||
|
||
def test_overlap_percentage(): | ||
bbox1 = [10, 10, 50, 50] | ||
bbox2 = [30, 30, 70, 70] | ||
assert overlap_percentage(bbox1, bbox2) > 0 | ||
|
||
|
||
def test_is_contained(): | ||
outer_bbox = [10, 10, 50, 50] | ||
inner_bbox = [20, 20, 30, 30] | ||
assert is_contained(inner_bbox, outer_bbox) | ||
assert not is_contained(outer_bbox, inner_bbox) |