fix: vectorize in batches. #980

Merged
merged 1 commit on Apr 9, 2024
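Reading the diff, the core change is that _qa_split now vectorizes the generated QA pairs one small batch at a time through a new _embedding_qa helper, aborting the task on the first failed batch, instead of handing the whole list to the embedding model in a single call; the retry path also becomes async, and document status and progress are finalized together once processing succeeds. A minimal sketch of that batching pattern, with a hypothetical embed_batch callable standing in for _embedding_qa (the name and batch size are illustrative, not from the repository):

# Sketch only: the batch-wise vectorization pattern used in _qa_split below.
# embed_batch is a hypothetical stand-in for _embedding_qa; batch_size=1
# mirrors the per-QA loop in the diff.
def vectorize_in_batches(qa_list, embed_batch, batch_size=1):
    for start in range(0, len(qa_list), batch_size):
        batch = qa_list[start:start + batch_size]
        response = embed_batch(batch)
        if response.get("status") != 200:
            # Fail fast: the caller marks the chunk and document as failed.
            return response
    return {"status": 200, "message": "", "data": ""}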
17 changes: 14 additions & 3 deletions pypi/data-processing/src/data_store_process/minio_store_process.py
@@ -407,7 +407,7 @@ async def text_manipulate(
return {"status": 400, "message": str(ex), "data": traceback.format_exc()}


def text_manipulate_retry(req_json, pool):
async def text_manipulate_retry(req_json, pool):
task_id = req_json.get("id")
creator = req_json.get("creator")
log_id = ulid.ulid()
@@ -470,7 +470,7 @@ def text_manipulate_retry(req_json, pool):
]
)
)
result = _text_manipulate_retry_for_document(
result = await _text_manipulate_retry_for_document(
document=document,
task_info=task_info_dict,
log_id=log_id,
@@ -937,7 +937,7 @@ def _insert_log_info(id, task_id, execute_type, creator, pool):
return {"status": 400, "message": str(ex), "data": traceback.format_exc()}


def _text_manipulate_retry_for_document(document, task_info, log_id, pool, creator):
async def _text_manipulate_retry_for_document(document, task_info, log_id, pool, creator):
file_name = document.get("file_name")
task_id = task_info.get("id")
document_id = document.get("id")
@@ -1025,6 +1025,16 @@ def _text_manipulate_retry_for_document(document, task_info, log_id, pool, creator):
task_id=task_id,
create_user=creator,
)
elif file_extension == "web":
# Handle .web files
result = await web_handle.web_manipulate(
file_name=file_name,
document_id=document_id,
support_type=support_type,
conn_pool=pool,
task_id=task_id,
create_user=creator,
)

# Delete the downloaded local file
_remove_local_file(file_name)
@@ -1042,6 +1052,7 @@ def _text_manipulate_retry_for_document(document, task_info, log_id, pool, creator):
file_name=file_name,
all_document_for_process=document_chunk_dict.get("data"),
support_type=support_type,
progress=int(document.get("progress")),
conn_pool=pool,
create_user=creator,
)
@@ -160,6 +160,35 @@ def update_document_progress(req_json, pool):
return res


def update_document_status_and_progress(req_json, pool):
"""Update the status and progress with id"""
now = date_time_utils.now_str()
program = "文件处理完成-修改"
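# The update_program label is stored in Chinese as-is; roughly "file processing completed - update".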

params = {
"id": req_json["id"],
"status": req_json["status"],
"end_time": now,
"progress": req_json["progress"],
"update_datetime": now,
"update_program": program,
}

sql = """
update public.data_process_task_document set
status = %(status)s,
end_time = %(end_time)s,
progress = %(progress)s,
update_datetime = %(update_datetime)s,
update_program = %(update_program)s
where
id = %(id)s
""".strip()

res = postgresql_pool_client.execute_update(pool, sql, params)
return res


def list_file_by_task_id(req_json, pool):
"""info with id"""
params = {"task_id": req_json["task_id"]}
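For context, the update_document_status_and_progress helper added above is driven from common_handle.py later in this diff; a minimal sketch of that call, assuming a document id and connection pool are already in hand:

# Sketch: mirrors _update_document_status_and_progress in common_handle.py below.
document_update_item = {"id": document_id, "status": "success", "progress": 100}
data_process_document_db_operate.update_document_status_and_progress(
    document_update_item, pool=conn_pool
)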
170 changes: 153 additions & 17 deletions pypi/data-processing/src/file_handle/common_handle.py
@@ -37,7 +37,12 @@


def text_manipulate(
all_document_for_process, file_name, support_type, conn_pool, create_user
all_document_for_process,
file_name,
support_type,
conn_pool,
create_user,
progress=0
):
"""Manipulate the text content.

@@ -63,7 +68,7 @@ def text_manipulate(
conn_pool=conn_pool,
)

text_process_success_num = 0
text_process_success_num = progress
for document in all_document_for_process:
document_chunk_id = document.get("id")
# Clean the data such as removing invisible characters.
@@ -116,11 +121,6 @@ def text_manipulate(
if qa_response.get("status") != 200:
return qa_response

# File processed successfully; update the file status in data_process_task_document
_updata_document_status_and_end_time(
id=document_id, status="success", conn_pool=conn_pool
)

if support_type_map.get("qa_split"):
# Whether QA split was selected
qa_list_dict = support_type_map.get("qa_split")
@@ -196,6 +196,13 @@ def text_manipulate(
file_name=file_name_csv, phase_value="final", data=qa_data_dict
)

_update_document_status_and_progress(
id=document_id,
status="success",
progress=100,
conn_pool=conn_pool
)

logger.debug(f"{log_tag_const.COMMON_HANDLE} Finish manipulating the text")
return {
"status": 200,
@@ -225,13 +232,25 @@ def text_manipulate(
file_name=file_name_csv, phase_value="final", data=chunk_data_dict
)

_update_document_status_and_progress(
id=document_id,
status="success",
progress=100,
conn_pool=conn_pool
)

logger.debug(f"{log_tag_const.COMMON_HANDLE} Finish manipulating the text")
return {
"status": 200,
"message": "",
"data": "",
}

# File processed successfully; update the file status in data_process_task_document
_updata_document_status_and_end_time(
id=document_id, status="success", conn_pool=conn_pool
)

return {"status": 200, "message": "", "data": ""}
except Exception as ex:
logger.error(
@@ -914,6 +933,7 @@ def _qa_split(
):
qa_list_dict = support_type_map.get("qa_split")
llm_config = qa_list_dict.get("llm_config")
remove_duplicate_config = qa_list_dict.get("remove_duplicate_config")

# Update the chunk status to started
_update_document_chunk_status_and_start_time(
@@ -937,6 +957,7 @@ def _qa_split(
id=document_id, status="fail", conn_pool=conn_pool
)
else:
qa_list = []
# Save the QA data into the table
qa_data = qa_response.get("data")
for _, item in enumerate(qa_data):
@@ -955,6 +976,34 @@ def _qa_split(
qa_insert_item, pool=conn_pool
)

qa_list.append(qa_insert_item)

# Whether deduplication is needed
if remove_duplicate_config:
for qa in qa_list:
embedding_response = _embedding_qa(
qa_list=[qa],
remove_duplicate_config=remove_duplicate_config,
conn_pool=conn_pool
)

if embedding_response.get("status") != 200:
# Processing failed
# Update the status in data_process_task_document_chunk
_updata_document_chunk_status_and_end_time(
id=document_chunk_id,
update_user=create_user,
status="fail",
conn_pool=conn_pool,
)

# Update the file status in data_process_task_document
_updata_document_status_and_end_time(
id=document_id, status="fail", conn_pool=conn_pool
)

return embedding_response

# Update the status in data_process_task_document_chunk
_updata_document_chunk_status_and_end_time(
id=document_chunk_id,
@@ -965,6 +1014,9 @@ def _qa_split(

# Update the file processing progress
progress = int(text_process_success_num / document_chunk_size * 100)
if text_process_success_num == document_chunk_size:
progress = 99

_updata_document_progress(
id=document_id,
progress=progress,
@@ -994,7 +1046,7 @@ def _generate_qa_list(content, llm_config):

# Generate the QA list.
qa_list = []
if llm_spec_info.get("data").get("provider").get("worker"):
if llm_config.get("provider") == "worker":
# get base url for configmap
base_url = model_cr.get_worker_base_url_k8s_configmap(
name=config.k8s_default_config, namespace=config.k8s_pod_namespace
@@ -1190,6 +1242,26 @@ def _updata_document_progress(id, progress, update_user, conn_pool):
return {"status": 1000, "message": str(ex), "data": traceback.format_exc()}


def _update_document_status_and_progress(id, status, progress, conn_pool):
try:
document_update_item = {"id": id, "status": status, "progress": progress}
data_process_document_db_operate.update_document_status_and_progress(
document_update_item, pool=conn_pool
)

return {"status": 200, "message": "", "data": ""}
except Exception as ex:
logger.error(
"".join(
[
f"{log_tag_const.COMMON_HANDLE} update document status ",
f"\n{traceback.format_exc()}",
]
)
)
return {"status": 1000, "message": str(ex), "data": traceback.format_exc()}


def _update_document_chunk_status_and_start_time(id, update_user, conn_pool):
try:
now = date_time_utils.now_str()
@@ -1292,8 +1364,8 @@ def _qa_remove_duplicate(qa_list, remove_duplicate_config, conn_pool):
provider = remove_duplicate_config.get("embedding_provider")
similarity = float(remove_duplicate_config.get("similarity"))

# Model details from the llms CR
llm_spec_info = model_cr.get_spec_for_embedding_k8s_cr(name=name, namespace=namespace)
# Model details from the embedding CR
embedding_spec_info = model_cr.get_spec_for_embedding_k8s_cr(name=name, namespace=namespace)

if provider == "worker":
# get base url for configmap
@@ -1319,11 +1391,11 @@ def _qa_remove_duplicate(qa_list, remove_duplicate_config, conn_pool):
)

remove_duplicate_loader = QARemoveDuplicate(embeddings=qa_embeddings, pool=conn_pool)
return remove_duplicate_loader.qa_remove_duplicate(qa_list, similarity)
return remove_duplicate_loader.remove_duplicate_qa_data(qa_list, similarity)
else:
endpoint = llm_spec_info.get("data").get("provider").get("endpoint")
endpoint = embedding_spec_info.get("data").get("provider").get("endpoint")
base_url = endpoint.get("url")
llm_type = llm_spec_info.get("data").get("type")
embedding_type = embedding_spec_info.get("data").get("type")

logger.debug(
"".join(
@@ -1332,19 +1404,83 @@ def _qa_remove_duplicate(qa_list, remove_duplicate_config, conn_pool):
f"name: {name}\n",
f"namespace: {namespace}\n",
f"model: {model}\n",
f"llm_type: {llm_type}\n",
f"embedding_type: {embedding_type}\n",
]
)
)

if embedding_type == "openai":
qa_embeddings = OpenAIEmbeddings(
api_key="fake",
base_url=base_url,
model=model,
)

remove_duplicate_loader = QARemoveDuplicate(embeddings=qa_embeddings, pool=conn_pool)
return remove_duplicate_loader.remove_duplicate_qa_data(qa_list, similarity)
else:
return {"status": 1000, "message": f"暂时不支持{embedding_type}类型的向量化模型模型", "data": ""}


def _embedding_qa(qa_list, remove_duplicate_config, conn_pool):
name = remove_duplicate_config.get("embedding_name")
namespace = remove_duplicate_config.get("embedding_namespace")
model = remove_duplicate_config.get("embedding_model")
provider = remove_duplicate_config.get("embedding_provider")

# Model details from the embeddings CR
embedding_spec_info = model_cr.get_spec_for_embedding_k8s_cr(name=name, namespace=namespace)

if provider == "worker":
# get base url for configmap
base_url = model_cr.get_worker_base_url_k8s_configmap(
name=config.k8s_default_config, namespace=config.k8s_pod_namespace
)
logger.debug(
"".join(
[
f"worker embedding \n",
f"name: {name}\n",
f"namespace: {namespace}\n",
f"model: {model}\n",
f"base_url: {base_url}\n",
]
)
)

qa_embeddings = OpenAIEmbeddings(
api_key="fake",
base_url=base_url,
model=model,
)

remove_duplicate_loader = QARemoveDuplicate(embeddings=qa_embeddings, pool=conn_pool)
return remove_duplicate_loader.embedding_qa_data(qa_list)
else:
endpoint = embedding_spec_info.get("data").get("provider").get("endpoint")
base_url = endpoint.get("url")
embedding_type = embedding_spec_info.get("data").get("type")

logger.debug(
"".join(
[
f"3rd_party embedding \n",
f"name: {name}\n",
f"namespace: {namespace}\n",
f"model: {model}\n",
f"embedding_type: {embedding_type}\n",
]
)
)

if llm_type == "openai":
if embedding_type == "openai":
qa_embeddings = OpenAIEmbeddings(
api_key="fake",
base_url=base_url,
model=model,
)

remove_duplicate_loader = QARemoveDuplicate(embeddings=qa_embeddings, pool=conn_pool)
return remove_duplicate_loader.qa_remove_duplicate(qa_list, similarity)
return remove_duplicate_loader.embedding_qa_data(qa_list)
else:
return {"status": 1000, "message": f"暂时不支持{llm_type}类型的向量化模型模型", "data": ""}
return {"status": 1000, "message": f"暂时不支持{embedding_type}类型的向量化模型模型", "data": ""}