similar_image #381

Open · wants to merge 1 commit into master
2 changes: 2 additions & 0 deletions app/config.py
@@ -24,6 +24,8 @@
     METADATA_SERVER_URL = getattr(seahub_settings, 'METADATA_SERVER_URL', '')
     ENABLE_METADATA_MANAGEMENT = getattr(seahub_settings, 'ENABLE_METADATA_MANAGEMENT', False)
     METADATA_FILE_TYPES = getattr(seahub_settings, 'METADATA_FILE_TYPES', {})
+    SEAFILE_AI_SERVER_URL = getattr(seahub_settings, 'SEAFILE_AI_SERVER_URL', '')
+    SEAFILE_AI_SECRET_KEY = getattr(seahub_settings, 'SEAFILE_AI_SECRET_KEY', '')
 except ImportError:
     logger.critical("Can not import seahub settings.")
     raise RuntimeError("Can not import seahub settings.")
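The two new settings are read from the seahub settings module at import time. A minimal sketch of the corresponding seahub_settings.py entries, with purely illustrative values:

# seahub_settings.py (illustrative values, not defaults)
SEAFILE_AI_SERVER_URL = 'http://127.0.0.1:8888'  # hypothetical address of the AI server
SEAFILE_AI_SECRET_KEY = 'shared-jwt-secret'      # must match the key the AI server verifies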
2 changes: 1 addition & 1 deletion repo_metadata/metadata_manager.py
@@ -60,7 +60,7 @@ def update_metadata_index(self, repo_id, old_commit_id, new_commit_id):
         renamed_dirs, moved_dirs = files

         self.repo_metadata.update(repo_id, added_files, deleted_files, added_dirs, deleted_dirs, modified_files,
-                                  renamed_files, moved_files, renamed_dirs, moved_dirs, new_commit_id)
+                                  renamed_files, moved_files, renamed_dirs, moved_dirs)

     def recovery(self, repo_id, from_commit, to_commit):
         logger.warning('%s: metadata in recovery', repo_id)
57 changes: 49 additions & 8 deletions repo_metadata/repo_metadata.py
@@ -52,6 +52,7 @@ def update_renamed_or_moved_files(self, repo_id, renamed_or_moved_files):
         if not renamed_or_moved_files:
             return

+        obj_ids = []
         base_sql = f'SELECT `{METADATA_TABLE.columns.id.name}`, `{METADATA_TABLE.columns.obj_id.name}` FROM `{METADATA_TABLE.name}` WHERE `_obj_id` IN ('
         sql = base_sql
         obj_id_to_file_dict = {}
@@ -66,26 +67,47 @@ def update_renamed_or_moved_files(self, repo_id, renamed_or_moved_files):
             sql += f'?, '
             parameters.append(obj_id)

+            file_type, file_ext = get_file_type_ext_by_name(os.path.basename(path))
+            if file_type == '_picture' and file_ext != 'gif':
+                obj_ids.append(file.obj_id)
+
             if len(parameters) >= METADATA_OP_LIMIT:
                 sql = sql.rstrip(', ') + ')'
                 self.update_rows_by_obj_ids(repo_id, sql, parameters, obj_id_to_file_dict)
                 sql = base_sql
                 parameters = []
                 obj_id_to_file_dict = {}

+                if obj_ids:
+                    data = {
+                        'task_type': 'modify_image_index',
+                        'repo_id': repo_id,
+                        'obj_ids': obj_ids
+                    }
+                    self.add_slow_task_to_queue(json.dumps(data))
+                    obj_ids = []
+
         if parameters:
             sql = sql.rstrip(', ') + ')'
             self.update_rows_by_obj_ids(repo_id, sql, parameters, obj_id_to_file_dict)

+        if obj_ids:
+            data = {
+                'task_type': 'modify_image_index',
+                'repo_id': repo_id,
+                'obj_ids': obj_ids
+            }
+            self.add_slow_task_to_queue(json.dumps(data))
+
     def update(self, repo_id, added_files, deleted_files, added_dirs, deleted_dirs, modified_files,
-               renamed_files, moved_files, renamed_dirs, moved_dirs, commit_id):
+               renamed_files, moved_files, renamed_dirs, moved_dirs):

         new_added_files, new_deleted_files, renamed_or_moved_files = self.cal_renamed_and_moved_files(added_files, deleted_files)

         # delete added_files delete added dirs for preventing duplicate insertions
         self.delete_files(repo_id, new_added_files)
         self.delete_dirs(repo_id, added_dirs)

-        self.add_files(repo_id, new_added_files, commit_id)
+        self.add_files(repo_id, new_added_files)
         self.delete_files(repo_id, new_deleted_files)
         # update renamed or moved files
         self.update_renamed_or_moved_files(repo_id, renamed_or_moved_files)
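For reference, the METADATA_OP_LIMIT batching pattern used throughout this file, which builds a placeholder IN (...) clause and flushes both the query and the queued image task whenever the limit is reached, can be isolated as a standalone sketch. The table and column literals below are illustrative; the real names come from METADATA_TABLE:

METADATA_OP_LIMIT = 1000  # assumed value; the real constant is defined in this module

def batched_obj_id_queries(obj_ids, limit=METADATA_OP_LIMIT):
    """Yield (sql, parameters) pairs with at most `limit` placeholders each."""
    base_sql = 'SELECT `_id`, `_obj_id` FROM `Table1` WHERE `_obj_id` IN ('
    sql, parameters = base_sql, []
    for obj_id in obj_ids:
        sql += '?, '
        parameters.append(obj_id)
        if len(parameters) >= limit:
            yield sql.rstrip(', ') + ')', parameters
            sql, parameters = base_sql, []
    if parameters:
        yield sql.rstrip(', ') + ')', parameters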
@@ -187,7 +209,7 @@ def update_rows_by_obj_ids(self, repo_id, sql, parameters, obj_id_to_file_dict):
             return
         self.metadata_server_api.update_rows(repo_id, METADATA_TABLE.id, updated_rows)

-    def add_files(self, repo_id, added_files, commit_id):
+    def add_files(self, repo_id, added_files):
         if not added_files:
             return

@@ -221,7 +243,7 @@ def add_files(self, repo_id, added_files, commit_id):

             if file_type:
                 row[METADATA_TABLE.columns.file_type.name] = file_type
-                if file_type == '_picture' and file_ext not in ('png', 'gif'):
+                if file_type == '_picture' and file_ext != 'gif':
                     obj_ids.append(de.obj_id)
             rows.append(row)
@@ -230,9 +252,8 @@ def add_files(self, repo_id, added_files, commit_id):

                 if obj_ids:
                     data = {
-                        'task_type': 'location_extract',
+                        'task_type': 'image_info_extract',
                         'repo_id': repo_id,
-                        'commit_id': commit_id,
                         'obj_ids': obj_ids
                     }
                     self.add_slow_task_to_queue(json.dumps(data))
@@ -243,9 +264,8 @@ def add_files(self, repo_id, added_files, commit_id):
         self.metadata_server_api.insert_rows(repo_id, METADATA_TABLE.id, rows)
         if obj_ids:
             data = {
-                'task_type': 'location_extract',
+                'task_type': 'image_info_extract',
                 'repo_id': repo_id,
-                'commit_id': commit_id,
                 'obj_ids': obj_ids
             }
             self.add_slow_task_to_queue(json.dumps(data))
@@ -257,6 +277,7 @@ def delete_files(self, repo_id, deleted_files):
         if not deleted_files:
             return

+        paths = []
         base_sql = f'SELECT `{METADATA_TABLE.columns.id.name}` FROM `{METADATA_TABLE.name}` WHERE'
         sql = base_sql
         parameters = []
@@ -269,18 +290,38 @@ def delete_files(self, repo_id, deleted_files):
             sql += f' (`{METADATA_TABLE.columns.parent_dir.name}` = ? AND `{METADATA_TABLE.columns.file_name.name}` = ?) OR'
             parameters.append(parent_dir)
             parameters.append(file_name)
+            file_type, file_ext = get_file_type_ext_by_name(file_name)
+            if file_type == '_picture' and file_ext != 'gif':
+                paths.append(path)

             if len(parameters) >= METADATA_OP_LIMIT:
                 sql = sql.rstrip(' OR')
                 self.delete_rows_by_query(repo_id, sql, parameters)
                 sql = base_sql
                 parameters = []

+                if paths:
+                    data = {
+                        'task_type': 'delete_image_index',
+                        'repo_id': repo_id,
+                        'paths': paths
+                    }
+                    self.add_slow_task_to_queue(json.dumps(data))
+                    paths = []
+
         if not parameters:
             return
         sql = sql.rstrip(' OR')
         self.delete_rows_by_query(repo_id, sql, parameters)

+        if paths:
+            data = {
+                'task_type': 'delete_image_index',
+                'repo_id': repo_id,
+                'paths': paths
+            }
+            self.add_slow_task_to_queue(json.dumps(data))
+
     def update_modified_files(self, repo_id, modified_files):
         if not modified_files:
             return
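All three image-index task types ('image_info_extract', 'modify_image_index', 'delete_image_index') are serialized to JSON before being queued. A minimal sketch of the producer-side handoff, assuming add_slow_task_to_queue() pushes onto the Redis-backed queue seafevents already uses; the queue name and transport details below are hypothetical:

import json
import redis  # assumes the redis-py client used elsewhere in seafevents

SLOW_TASK_QUEUE = 'metadata_slow_task'  # hypothetical queue name

def add_slow_task_to_queue(client, payload):
    # Producer side: push one JSON-encoded task onto the queue.
    client.lpush(SLOW_TASK_QUEUE, payload)

client = redis.Redis(host='127.0.0.1', port=6379)
add_slow_task_to_queue(client, json.dumps({
    'task_type': 'modify_image_index',    # or 'image_info_extract'
    'repo_id': 'repo-uuid',
    'obj_ids': ['obj-id-1', 'obj-id-2'],  # 'delete_image_index' carries 'paths' instead
}))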
35 changes: 35 additions & 0 deletions repo_metadata/seafile_ai_api.py
@@ -0,0 +1,35 @@
+import requests, jwt, time
+
+from seafevents.app.config import SEAFILE_AI_SERVER_URL, SEAFILE_AI_SECRET_KEY
+
+
+def parse_response(response):
+    if response.status_code >= 400 or response.status_code < 200:
+        raise ConnectionError(response.status_code, response.text)
+    else:
+        try:
+            return response.json()
+        except:
+            pass
+
+
+class SeafileAIAPI:
+    def __init__(self, timeout=30):
+        self.timeout = timeout
+        self.secret_key = SEAFILE_AI_SECRET_KEY
+        self.server_url = SEAFILE_AI_SERVER_URL
+
+    def gen_headers(self):
+        payload = {'exp': int(time.time()) + 300, }
+        token = jwt.encode(payload, self.secret_key, algorithm='HS256')
+        return {"Authorization": "Token %s" % token}
+
+    def images_embedding(self, repo_id, obj_ids):
+        headers = self.gen_headers()
+        url = f'{self.server_url}/api/v1/images-embedding/'
+        data = {
+            'repo_id': repo_id,
+            'obj_ids': obj_ids,
+        }
+        response = requests.post(url, json=data, headers=headers, timeout=self.timeout)
+        return parse_response(response)
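A minimal usage sketch of the new client, assuming SEAFILE_AI_SERVER_URL and SEAFILE_AI_SECRET_KEY are configured and the AI server is reachable; the response shape ({'data': [{'obj_id': ..., 'embedding': ...}, ...]}) is inferred from how slow_task_handler.py consumes it:

from seafevents.repo_metadata.seafile_ai_api import SeafileAIAPI

api = SeafileAIAPI(timeout=30)
result = api.images_embedding('repo-uuid', ['obj-id-1', 'obj-id-2'])
for item in result.get('data', []):  # inferred response shape
    print(item['obj_id'], len(item['embedding']))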
109 changes: 100 additions & 9 deletions repo_metadata/slow_task_handler.py
@@ -1,3 +1,4 @@
+import os.path
 import time
 import logging
 import threading
@@ -10,6 +11,11 @@
 from seafevents.repo_metadata.metadata_server_api import MetadataServerAPI
 from seafevents.repo_metadata.repo_metadata import METADATA_OP_LIMIT
 from seafevents.repo_metadata.utils import METADATA_TABLE, get_latlng
+from seafevents.repo_metadata.seafile_ai_api import SeafileAIAPI
+from seafevents.seasearch.utils.seasearch_api import SeaSearchAPI
+from seafevents.utils import parse_bool
+from seafevents.seasearch.index_store.repo_image_index import RepoImageIndex
+from seafevents.seasearch.utils.constants import REPO_IMAGE_INDEX_PREFIX

 logger = logging.getLogger(__name__)

@@ -20,6 +26,8 @@ class SlowTaskHandler(object):

     def __init__(self, config):
         self.metadata_server_api = MetadataServerAPI('seafevents')
+        self.seafile_ai_api = SeafileAIAPI()
+        self.seasearch_api = None

         self.should_stop = threading.Event()
         self.mq_server = '127.0.0.1'
@@ -29,6 +37,7 @@ def __init__(self, config):
         self._parse_config(config)

         self.mq = get_mq(self.mq_server, self.mq_port, self.mq_password)
+        self.repo_image_index = RepoImageIndex(self.seasearch_api)

     def _parse_config(self, config):
         redis_section_name = 'REDIS'
@@ -46,6 +55,15 @@ def _parse_config(self, config):
         if config.has_section(metadata_section_name):
             self.worker_num = get_opt_from_conf_or_env(config, metadata_section_name, key_index_workers, default=3)

+        seasearch_section_name = 'SEASEARCH'
+        seasearch_key_enabled = 'enabled'
+        if config.has_section(seasearch_section_name):
+            enabled = get_opt_from_conf_or_env(config, seasearch_section_name, seasearch_key_enabled, default=False)
+            if parse_bool(enabled):
+                seasearch_url = get_opt_from_conf_or_env(config, seasearch_section_name, 'seasearch_url')
+                seasearch_token = get_opt_from_conf_or_env(config, seasearch_section_name, 'seasearch_token')
+                self.seasearch_api = SeaSearchAPI(seasearch_url, seasearch_token)
+
     @property
     def tname(self):
         return threading.current_thread().name
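The handler only constructs a SeaSearchAPI client when a [SEASEARCH] section enables it. A sketch of the corresponding seafevents config excerpt; the URL and token values are placeholders:

# seafevents config excerpt (placeholder values)
[SEASEARCH]
enabled = true
seasearch_url = http://127.0.0.1:4080
seasearch_token = your-seasearch-token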
@@ -82,48 +100,121 @@ def worker_handler(self):

     def slow_task_handler(self, repo_id, data):
         task_type = data.get('task_type')
-        if task_type == 'location_extract':
-            self.extract_image_location(repo_id, data)
+        if task_type == 'image_info_extract':
+            self.extract_image_info(repo_id, data)
+        elif task_type == 'delete_image_index':
+            self.delete_image_index(repo_id, data)
+        elif task_type == 'modify_image_index':
+            self.modify_image_index(repo_id, data)
+
+    def delete_image_index(self, repo_id, data):
+        logger.info('%s start delete image index repo %s' % (threading.currentThread().getName(), repo_id))
+
+        try:
+            paths = data.get('paths')
+            if paths:
+                repo_image_index_name = REPO_IMAGE_INDEX_PREFIX + repo_id
+                self.repo_image_index.delete_images(repo_image_index_name, paths)
+        except Exception as e:
+            logger.exception('repo: %s, delete image index error: %s', repo_id, e)
+
+    def modify_image_index(self, repo_id, data):
+        logger.info('%s start modify image index repo %s' % (threading.currentThread().getName(), repo_id))
+
+        try:
+            obj_ids = data.get('obj_ids')
+            sql = f'SELECT `{METADATA_TABLE.columns.id.name}`, `{METADATA_TABLE.columns.obj_id.name}`, `{METADATA_TABLE.columns.parent_dir.name}`, `{METADATA_TABLE.columns.file_name.name}`, `{METADATA_TABLE.columns.image_feature.name}` FROM `{METADATA_TABLE.name}` WHERE `{METADATA_TABLE.columns.obj_id.name}` IN ('
+            parameters = []
+
+            for obj_id in obj_ids:
+                sql += '?, '
+                parameters.append(obj_id)
+
+            if not parameters:
+                return
+            sql = sql.rstrip(', ') + ');'
+            query_result = self.metadata_server_api.query_rows(repo_id, sql, parameters).get('results', [])
+            if not query_result:
+                return
+
+            images_data = []
+            for row in query_result:
+                parent_dir = row[METADATA_TABLE.columns.parent_dir.name]
+                file_name = row[METADATA_TABLE.columns.file_name.name]
+                image_feature = row[METADATA_TABLE.columns.image_feature.name]
+                images_data.append({
+                    'path': os.path.join(parent_dir, file_name),
+                    'embedding': json.loads(image_feature) if image_feature else '',
+                })
+
+            if images_data:
+                repo_image_index_name = REPO_IMAGE_INDEX_PREFIX + repo_id
+                self.repo_image_index.create_index_if_missing(repo_image_index_name)
+                paths = [image['path'] for image in images_data]
+                self.repo_image_index.delete_images(repo_image_index_name, paths)
+                self.repo_image_index.add_images(repo_image_index_name, images_data)
+        except Exception as e:
+            logger.exception('repo: %s, modify image index error: %s', repo_id, e)

-    def extract_image_location(self, repo_id, data):
-        logger.info('%s start extract image location repo %s' % (threading.currentThread().getName(), repo_id))
+    def extract_image_info(self, repo_id, data):
+        logger.info('%s start extract image info repo %s' % (threading.currentThread().getName(), repo_id))

         try:
             obj_ids = data.get('obj_ids')
-            commit_id = data.get('commit_id')
-            sql = f'SELECT `{METADATA_TABLE.columns.id.name}`, `{METADATA_TABLE.columns.obj_id.name}` FROM `{METADATA_TABLE.name}` WHERE `{METADATA_TABLE.columns.obj_id.name}` IN ('
+            sql = f'SELECT `{METADATA_TABLE.columns.id.name}`, `{METADATA_TABLE.columns.obj_id.name}`, `{METADATA_TABLE.columns.parent_dir.name}`, `{METADATA_TABLE.columns.file_name.name}` FROM `{METADATA_TABLE.name}` WHERE `{METADATA_TABLE.columns.obj_id.name}` IN ('
             parameters = []

             obj_id_to_extract_info = {}
             updated_rows = []
             for obj_id in obj_ids:
-                obj_id_to_extract_info[obj_id] = get_latlng(repo_id, commit_id, obj_id)
+                obj_id_to_extract_info[obj_id] = {
+                    'location': get_latlng(repo_id, obj_id)
+                }
                 sql += '?, '
                 parameters.append(obj_id)

+            embeddings = self.seafile_ai_api.images_embedding(repo_id, obj_ids).get('data', [])
+            for embedding in embeddings:
+                obj_id = embedding['obj_id']
+                if obj_id in obj_id_to_extract_info:
+                    obj_id_to_extract_info[obj_id]['embedding'] = embedding['embedding']
+
             if not parameters:
                 return
             sql = sql.rstrip(', ') + ');'
             query_result = self.metadata_server_api.query_rows(repo_id, sql, parameters).get('results', [])
             if not query_result:
                 return

+            images_data = []
             for row in query_result:
                 row_id = row[METADATA_TABLE.columns.id.name]
                 obj_id = row[METADATA_TABLE.columns.obj_id.name]
-                lat, lng = obj_id_to_extract_info.get(obj_id)
+                parent_dir = row[METADATA_TABLE.columns.parent_dir.name]
+                file_name = row[METADATA_TABLE.columns.file_name.name]
+                lat, lng = obj_id_to_extract_info.get(obj_id, {}).get('location')
+                embedding = obj_id_to_extract_info.get(obj_id, {}).get('embedding')
                 update_row = {
                     METADATA_TABLE.columns.id.name: row_id,
                     METADATA_TABLE.columns.location.name: {'lng': lng, 'lat': lat},
+                    METADATA_TABLE.columns.image_feature.name: json.dumps(embedding) if embedding else '',
                 }
                 updated_rows.append(update_row)
+                images_data.append({
+                    'path': os.path.join(parent_dir, file_name),
+                    'embedding': embedding,
+                })

                 if len(updated_rows) >= METADATA_OP_LIMIT:
                     self.metadata_server_api.update_rows(repo_id, METADATA_TABLE.id, updated_rows)
                     updated_rows = []

+            if images_data:
+                repo_image_index_name = REPO_IMAGE_INDEX_PREFIX + repo_id
+                self.repo_image_index.create_index_if_missing(repo_image_index_name)
+                self.repo_image_index.add_images(repo_image_index_name, images_data)
             if not updated_rows:
                 return
             self.metadata_server_api.update_rows(repo_id, METADATA_TABLE.id, updated_rows)
         except Exception as e:
-            logger.exception('repo: %s, update metadata location error: %s', repo_id, e)
+            logger.exception('repo: %s, update metadata image info error: %s', repo_id, e)
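A hedged round-trip sketch of how a queued payload reaches the new handlers, assuming handler is a fully constructed SlowTaskHandler and the payload was produced by repo_metadata.py above:

import json

payload = json.dumps({
    'task_type': 'delete_image_index',
    'repo_id': 'repo-uuid',
    'paths': ['/photos/a.jpg'],
})

data = json.loads(payload)
handler.slow_task_handler(data['repo_id'], data)  # dispatches to delete_image_index()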