#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Find similar images by comparing image keypoints.

The processing pipeline and choice of algorithms are not my own.
"""
from pathlib import Path
from typing import Optional
import argparse
import multiprocessing
import os
import sys

import cv2
import numpy as np
# Globals for user interface
# Number of top results to print to the user
NTOP = 50
# ANSI escape code to clear line
ANSI_CL = '\033[K'
# Globals for image scanning algorithms
# Lowe ratio, essentially used to determine if two keypoints match
LOWE_RATIO = 0.7
# Requires OPENCV_ENABLE_NONFREE=ON when compiling OpenCV
_SIFT = cv2.SIFT_create()
# ORB detector does not perform well, but it is under a free license
#_ORB = cv2.ORB_create()
# Brute Force matcher has limited number of descriptors:
# cv2.error: OpenCV(4.5.0) ../modules/features2d/src/matchers.cpp:860: error: (-215:Assertion failed) trainDescCollection[iIdx].rows < IMGIDX_ONE in function 'knnMatchImpl'
# So we hardcode limit here:
#BF_IMGIDX_SHIFT = 18
#BF_IMGIDX_ONE = 1 << BF_IMGIDX_SHIFT
#_BF = cv2.BFMatcher()
FLANN_KNN_MATCHES = 2 # For _FLANN.knnMatch
FLANN_INDEX_KDTREE = 0
flann_index_params = dict(algorithm=FLANN_INDEX_KDTREE, trees=5)
flann_search_params = dict(checks=50) # or pass empty dictionary
_FLANN = cv2.FlannBasedMatcher(flann_index_params, flann_search_params)
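# Note on the FLANN parameters above (per the usual FLANN/OpenCV documentation;
# stated here as a reference, not verified in this script): 'trees' is the
# number of randomized kd-trees built for the index, and 'checks' bounds how
# many leaf nodes are visited per query, trading search accuracy for speed.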
# Globals for a worker in the process pool
# Loaded reference descriptor
_REFERENCE_DES: Optional[np.ndarray] = None
# Maximum length of the larger dimension when downscaling images
_MAX_LENGTH: Optional[int] = None
# Base code for image searching
class _InvalidComputation(Exception):
    # Thrown on invalid computation for an image
    pass
def downscale_img(img, max_length):
    height, width = img.shape[:2]
    if height < max_length and width < max_length:
        # Image is already smaller than we need
        return img
    if height > width:
        new_dim = (int(width / float(height) * max_length), max_length)
    else:
        new_dim = (max_length, int(height / float(width) * max_length))
    return cv2.resize(img, new_dim, interpolation=cv2.INTER_AREA)
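# Worked example for downscale_img (sizes invented for illustration): with
# max_length=640, a 4000x3000 (WxH) landscape image becomes 640x480, a
# 3000x4000 portrait image becomes 480x640, and a 320x200 image is returned
# unchanged because both dimensions are already below 640.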
def compute_descriptor(img_path, max_length):
    # Read in grayscale (flag 0); imread returns None for unreadable files
    img = cv2.imread(str(img_path), 0)
    if img is None:
        raise _InvalidComputation('Not a valid image')
    img = downscale_img(img, max_length)
    #kp, des = _ORB.detectAndCompute(img, None)
    kp, des = _SIFT.detectAndCompute(img, None)
    #kp, des = _SURF.detectAndCompute(img, None)
    #print(img_path, des.dtype, des.shape)
    if not kp:
        raise _InvalidComputation('Could not find any keypoints')
    if des is None:
        raise _InvalidComputation('descriptor is None')
    return des
def get_good_matches(des1, des2):
    # Because of https://github.com/opencv/opencv/issues/10548
    if des2.shape[0] < FLANN_KNN_MATCHES:
        raise _InvalidComputation('train descriptor has too few entries')
    matches = _FLANN.knnMatch(des1, des2, k=FLANN_KNN_MATCHES)
    if not matches:
        raise _InvalidComputation('No matches found')
    if len(matches[0]) != 2:
        raise _InvalidComputation(f'Matches columns must have 2 entries, got {len(matches[0])}')
    # Ratio test as per Lowe's paper
    #good = list()
    ngood = 0
    for m, n in matches:
        if m.distance < LOWE_RATIO * n.distance:
            #good.append(m)
            ngood += 1
    return ngood
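# Illustrative ratio-test arithmetic (distances invented for the example):
# with LOWE_RATIO = 0.7, a keypoint whose best match has distance 100 and
# whose second-best match has distance 180 counts as good (100 < 0.7 * 180
# = 126), while one with distances 100 and 120 is rejected (100 >= 0.7 * 120
# = 84).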
# User interface common functions
def status_msg(*args, **kwargs):
    print(ANSI_CL, end='')
    print(*args, end='\r', **kwargs)


def error_msg(*args, **kwargs):
    print(ANSI_CL, end='')
    print(*args, file=sys.stderr, **kwargs)


def info_msg(*args, clear=False, **kwargs):
    if clear:
        print(ANSI_CL, end='')
    print(*args, **kwargs)
# Process pool worker functions
def init_worker(img_path, max_length):
    global _REFERENCE_DES, _MAX_LENGTH
    assert _REFERENCE_DES is None
    assert _MAX_LENGTH is None
    _REFERENCE_DES = compute_descriptor(img_path, max_length)
    _MAX_LENGTH = max_length
def compute_ngood(img_path):
    assert _REFERENCE_DES is not None
    if not img_path.is_file():
        error_msg('Skipping non-file', img_path)
        return
    if img_path.is_symlink():
        error_msg('Skipping symlink', img_path)
        return
    try:
        status_msg('Processing', img_path)
        try:
            scan_des = compute_descriptor(img_path, _MAX_LENGTH)
        except _InvalidComputation as exc:
            error_msg(f'Skipping img with compute_descriptor error "{str(exc)}":', img_path)
            return
        try:
            ngood = get_good_matches(_REFERENCE_DES, scan_des)
        except _InvalidComputation as exc:
            error_msg(f'Skipping img with get_good_matches error "{str(exc)}":', img_path)
            return
    except BaseException as exc:
        error_msg(f'Threw exception on {img_path}: {exc}')
        return
    return ngood, img_path
# Main process functions
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--workers',
        type=int,
        help=f'Number of worker subprocesses to launch. If not specified, defaults to '
        f'the number of CPU threads (found: {os.cpu_count()}).')
    parser.add_argument(
        '--chunksize',
        type=int,
        default=2,
        help='Chunksize for the multiprocessing map operation. Default: %(default)s')
    parser.add_argument(
        '--resize',
        type=int,
        default=640,
        help='Maximum length of the larger image dimension when downscaling images '
        'before computing descriptors. Default: %(default)s')
    parser.add_argument('query_img', type=Path, help='Path to the query image')
    parser.add_argument(
        'img_root',
        type=Path,
        help='Path to the directory containing images to search through')
    args = parser.parse_args()
    status_msg('Initializing workers...')
    with multiprocessing.Pool(args.workers, init_worker, [args.query_img, args.resize]) as pool:
        # filter(None, ...) uses the identity function (i.e. lambda x: x) as the filter.
        # Since compute_ngood returns either None (on error) or a tuple (on success),
        # this drops all the None results.
        all_matches = list(
            filter(
                None,
                pool.imap_unordered(compute_ngood,
                                    args.img_root.rglob('*'),
                                    chunksize=args.chunksize)))
    all_matches.sort()
    info_msg(f'Top {NTOP} matches:', clear=True)
    for ngood, scan_path in all_matches[-NTOP:]:
        info_msg(scan_path, f'({ngood})')


if __name__ == '__main__':
    main()
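# Example invocation (paths are hypothetical, shown only for illustration):
#   ./pic_finder.py --workers 4 query.jpg ~/Pictures
# This ranks every file under ~/Pictures by its number of good keypoint
# matches against query.jpg and prints the top NTOP paths, best match last.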