import numpy as np
import json
from PIL import ImageDraw, Image, ImageFont
from ultralytics import YOLO
from vietocr.tool.predictor import Predictor
from vietocr.tool.config import Cfg
import cv2
import torch
import os
import preprocessing as pre
import realesrgan as real
import postprocessing as post

class_name_dict = {0: 'addr',
                   1: 'addr_line2',
                   2: 'class',
                   3: 'day',
                   4: 'dob',
                   5: 'expiry',
                   6: 'license_no',
                   7: 'month',
                   8: 'name',
                   9: 'nationality',
                   10: 'place',
                   11: 'year'}

def read_image(image_path):
    img = Image.open(image_path)  # open the image file as a PIL image
    return img


def model_result(model_path, img):
    model = YOLO(model_path)  # load the YOLOv8 detection model
    results = model(img)[0]   # run inference and keep the first result
    return results


def set_detector():
    # Configure the VietOCR text recognizer with the vgg_transformer preset, running on CPU
    config = Cfg.load_config_from_name('vgg_transformer')
    # config['weights'] = '/content/drive/MyDrive/license-extractor/weight/vgg_transformer.pth'
    config['cnn']['pretrained'] = False
    config['device'] = 'cpu'
    detector = Predictor(config)
    return detector

def ocr_extract(img_path, model_path, threshold):  # input is an image path; the file is opened as a PIL image
    img = read_image(img_path)
    labeled_objects = []
    results = model_result(model_path, img)
    detector = set_detector()
    i = 0
    for result in results.boxes.data.tolist():
        x1, y1, x2, y2, score, class_id = result
        if score > threshold:
            # Draw a bounding box around the object
            # cv2.rectangle(img, (int(x1-1), int(y1+1)), (int(x2+1), int(y2-1)), (0, 255, 0), 2)
            # Crop the detected region (ROI) out of the original image
            img_array = np.array(img)
            # ocr_img = img[int(y1):int(y2), int(x1):int(x2), :].copy()
            roi = img_array[int(y1):int(y2), int(x1):int(x2)]
            ocr_img = Image.fromarray(roi)
            ocr_img.save('temp/o' + str(i) + '.jpg')
            i = i + 1
            # Preprocess ROI images (currently disabled)
            # roi = cv2.cvtColor(ocr_img, cv2.COLOR_BGR2GRAY)
            # _, roi = cv2.threshold(roi, 128, 255, cv2.THRESH_BINARY + cv2.THRESH_OTSU)
            # Perform OCR on the ROI with the VietOCR detector
            text, prob = detector.predict(ocr_img, return_prob=True)
            # Check if the class_id exists in class_name_dict
            if class_id in class_name_dict:
                class_name = class_name_dict[class_id]
                if class_name != 'person':  # do not write the portrait info
                    # Store the relevant information in the labeled_objects list
                    labeled_objects.append({
                        "class": class_name,
                        "text": text.strip(),
                        "prob": prob,
                        "confidence": score
                    })
    # Save labeled objects information to a JSON file
    with open('temp/labeled_objects.json', 'w', encoding='utf-8') as json_file:
        json.dump(labeled_objects, json_file, ensure_ascii=False, indent=4)
    labeled_objects = []  # clear data after use

# Save a copy of the image with bounding boxes drawn and each box labeled with its class
def save_output_image(img_path, model_path, threshold, name):
    img = read_image(img_path)
    # Create a drawing object to draw on the image
    draw = ImageDraw.Draw(img)
    results = model_result(model_path, img)
    font = ImageFont.truetype("arial.ttf", 36)  # Use a TTF font file and specify the font size
    # Define the text color
    text_color = (255, 0, 0)  # RGB color
    for result in results.boxes.data.tolist():
        x1, y1, x2, y2, score, class_id = result
        if score > threshold:
            if class_id in class_name_dict:
                class_name = class_name_dict[class_id]
                # Draw a bounding box around the object
                draw.rectangle([(int(x1), int(y1)), (int(x2), int(y2))], outline=(0, 255, 0), width=3)
                # Define the position for the text
                text_position = (int(x1), int(y1) - 10)  # Adjust the text position as needed
                # Draw the class name on the image
                draw.text(text_position, f"{class_name}", fill=text_color, font=font)
    img.save(f'output/{name}/output.jpg')

def processing(input_img_path, name):
    print("Cuda available:", torch.cuda.is_available())
    # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # no need to change
    directory_path = f'output/{name}/'
    if not os.path.exists(directory_path):
        # If it doesn't exist, create the directory and any necessary parent directories
        os.makedirs(directory_path)
    else:
        print(f"The directory {directory_path} already exists")
    # preprocessed = pre.preprocessing(input_img_path)
    model_path = 'license-yolov8xv4.pt'  # path to the trained YOLOv8 weights
    # Enhance the input with Real-ESRGAN before detection and OCR
    img = cv2.imread(input_img_path)
    real_in_path = 'temp/real_in.jpg'
    cv2.imwrite(real_in_path, img)
    real_out = real.realesrgan(real_in_path)
    ocr_extract(real_out, model_path, 0.5)
    save_output_image(real_out, model_path, 0.5, name)
    # Post-process the extracted fields
    post.filter("temp/labeled_objects.json", name)
    post.merge_permanent_residence(f'output/{name}/labeled_objects_filled.json', name)
    post.mean_prob(f'output/{name}/labeled_objects_filled.json', name)


# processing('input/20231002_150945.jpg', 'thinh1')
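

# Example entry point (a minimal sketch, not part of the original pipeline): the input path and
# run name below are hypothetical placeholders, and the script assumes a 'temp/' scratch directory
# exists for the intermediate crops and JSON output.
if __name__ == '__main__':
    os.makedirs('temp', exist_ok=True)  # make sure the assumed scratch directory exists
    processing('input/sample_license.jpg', 'sample_run')  # hypothetical input image and output folder name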