Skip to content

Commit

Permalink
Merge pull request #71 from breezedeus/pytorch
Browse files Browse the repository at this point in the history
download models from different oss urls, based on env vars
  • Loading branch information
breezedeus committed Oct 9, 2023
2 parents 160a866 + 31d4cc0 commit b064e3a
Show file tree
Hide file tree
Showing 12 changed files with 146 additions and 57 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ package:
rm -rf build
python setup.py sdist bdist_wheel

VERSION = 1.2.3.4
VERSION = 1.2.3.5
upload:
python -m twine upload dist/cnstd-$(VERSION)* --verbose

Expand Down
29 changes: 21 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,19 @@
</div>

# CnSTD
# Update 2023.06.30:发布 V1.2.3
# Update 2023.10.09:发布 V1.2.3.5

主要变更:
* 修复了模型文件自动下载的功能。HuggingFace似乎对下载文件的逻辑做了调整,导致之前版本的自动下载失败,当前版本已修复。但由于HuggingFace国内被墙,国内下载仍需 **梯子(VPN)**
* 更新了各个依赖包的版本号。

# Update 2023.06.20:
* 支持基于环境变量 `CNSTD_DOWNLOAD_SOURCE` 的取值,来决定不同的模型下载路径,默认使用国内OSS地址。
* `LayoutAnalyzer` 中增加了参数 `model_categories``model_arch_yaml`,用于指定模型的类别名称列表和模型架构。

...

# Update 2023.06.30:发布 V1.2.3

主要变更:

* 基于新标注的数据,重新训练了 **MFD YoloV7** 模型,目前新模型已部署到 [P2T网页版](https://p2t.behye.com) 。具体说明见:[Pix2Text (P2T) 新版公式检测模型 | Breezedeus.com](https://www.breezedeus.com/article/p2t-mfd-20230613)
* 之前的 MFD YoloV7 模型已开放给星球会员下载,具体说明见:[P2T YoloV7 数学公式检测模型开放给星球会员下载 | Breezedeus.com](https://www.breezedeus.com/article/p2t-yolov7-for-zsxq-20230619)
* 增加了一些Label Studio相关的脚本,见 [scripts](scripts) 。如:利用 CnSTD 自带的 MFD 模型对目录中的图片进行公式检测后生成可导入到Label Studio中的JSON文件;以及,Label Studio标注后把导出的JSON文件转换成训练 MFD 模型所需的数据格式。注意,MFD 模型的训练代码在 [yolov7](https://github.com/breezedeus/yolov7)`dev` branch)中。
Expand Down Expand Up @@ -358,13 +362,15 @@ class LayoutAnalyzer(object):
self,
model_name: str = 'mfd', # 'layout' or 'mfd'
*,
model_type: str = 'yolov7_tiny',
model_type: str = 'yolov7_tiny', # 当前支持 [`yolov7_tiny`, `yolov7`]'
model_backend: str = 'pytorch',
model_categories: Optional[List[str]] = None,
model_fp: Optional[str] = None,
model_arch_yaml: Optional[str] = None,
root: Union[str, Path] = data_dir(),
device: str = 'cpu',
**kwargs,
):
)
```
其中的参数含义如下:
Expand All @@ -375,8 +381,12 @@ class LayoutAnalyzer(object):
- `model_backend`: 字符串类型,表示backend。当前仅支持: 'pytorch';默认值:'pytorch'
- `model_categories`: 模型的检测类别名称。默认值:None,表示基于 `model_name` 自动决定
- `model_fp`: 字符串类型,表示模型文件的路径。默认值:`None`,表示使用默认的文件路径
- `model_arch_yaml`: 架构文件路径,例如 'yolov7-mfd.yaml';默认值为 None,表示将自动选择。
- `root`: 字符串或`Path`类型,表示模型文件所在的根目录。
- Linux/Mac下默认值为 `~/.cnstd`,表示模型文件所处文件夹类似 `~/.cnstd/1.2/analysis`
- Windows下默认值为 `C:/Users/<username>/AppData/Roaming/cnstd`
Expand Down Expand Up @@ -502,18 +512,21 @@ Usage: cnstd analyze [OPTIONS]
对给定图片进行 MFD 或者 版面分析。
Options:
-m, --model-name [mfd|layout] 模型类型。`mfd` 表示数学公式检测,`layout`
-m, --model-name TEXT 模型类型。`mfd` 表示数学公式检测,`layout`
表示版面分析;默认为:`mfd`
-t, --model-type TEXT 模型类型。当前支持 [`yolov7_tiny`, `yolov7`]
-b, --model-backend [pytorch|onnx]
模型后端架构。当前仅支持 `pytorch`
-c, --model-categories TEXT 模型的检测类别名称(","分割)。默认值:None,表示基于 `model_name`
自动决定
-p, --model-fp TEXT 使用训练好的模型。默认为 `None`,表示使用系统自带的预训练模型
-y, --model-arch-yaml TEXT 模型的配置文件路径
--device TEXT cuda device, i.e. 0 or 0,1,2,3 or cpu
-i, --img-fp TEXT 待分析的图片路径或图片目录
-o, --output-fp TEXT 分析结果输出的图片路径。默认为 `None`,会存储在当前文件夹,文件名称为输入文件名称
前面增加`out-`;如输入文件名为 `img.jpg`, 输出文件名即为 `out-
img.jpg`;如果输入为目录,则此路径也应该是一个目录,会将输出文件存储在此目录下
--resized-shape INTEGER 分析时把图片resize到此大小再进行。默认为 `700`
--resized-shape INTEGER 分析时把图片resize到此大小再进行。默认为 `608`
--conf-thresh FLOAT Confidence Threshold。默认值为 `0.25`
--iou-thresh FLOAT IOU threshold for NMS。默认值为 `0.45`
-h, --help Show this message and exit.
Expand Down
7 changes: 7 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
# Release Notes

# Update 2023.10.09:发布 V1.2.3.5

主要变更:

* 支持基于环境变量 `CNSTD_DOWNLOAD_SOURCE` 的取值,来决定不同的模型下载路径。
* `LayoutAnalyzer` 中增加了参数 `model_categories``model_arch_yaml`,用于指定模型的类别名称列表和模型架构。

# Update 2023.09.23:发布 V1.2.3.4

主要变更:
Expand Down
2 changes: 1 addition & 1 deletion cnstd/__version__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@
# specific language governing permissions and limitations
# under the License.

__version__ = '1.2.3.4'
__version__ = '1.2.3.5'
37 changes: 26 additions & 11 deletions cnstd/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,7 @@ def _vis_bool(img, fp):
default=None,
help='使用训练好的模型。默认为 `None`,表示使用系统自带的预训练模型',
)
@click.option(
"-r", "--rotated-bbox", is_flag=True, help="是否检测带角度(非水平和垂直)的文本框"
)
@click.option("-r", "--rotated-bbox", is_flag=True, help="是否检测带角度(非水平和垂直)的文本框")
@click.option(
"--resized-shape",
type=str,
Expand Down Expand Up @@ -323,7 +321,7 @@ def resave_model_file(
@click.option(
'-m',
'--model-name',
type=click.Choice(['mfd', 'layout']),
type=str,
default='mfd',
help='模型类型。`mfd` 表示数学公式检测,`layout` 表示版面分析;默认为:`mfd`',
)
Expand All @@ -341,13 +339,21 @@ def resave_model_file(
default='pytorch',
help='模型后端架构。当前仅支持 `pytorch`',
)
@click.option(
'-c',
'--model-categories',
type=str,
default=None,
help='模型的检测类别名称(","分割)。默认值:None,表示基于 `model_name` 自动决定',
)
@click.option(
'-p',
'--model-fp',
type=str,
default=None,
help='使用训练好的模型。默认为 `None`,表示使用系统自带的预训练模型',
)
@click.option('-y', '--model-arch-yaml', type=str, default=None, help='模型的配置文件路径')
@click.option('--device', default='', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')
@click.option(
'-i', '--img-fp', type=str, default='./examples/mfd/zh.jpg', help='待分析的图片路径或图片目录'
Expand All @@ -358,11 +364,11 @@ def resave_model_file(
type=str,
default=None,
help='分析结果输出的图片路径。默认为 `None`,会存储在当前文件夹,文件名称为输入文件名称前面增加`out-`;'
'如输入文件名为 `img.jpg`, 输出文件名即为 `out-img.jpg`;'
'如果输入为目录,则此路径也应该是一个目录,会将输出文件存储在此目录下',
'如输入文件名为 `img.jpg`, 输出文件名即为 `out-img.jpg`;'
'如果输入为目录,则此路径也应该是一个目录,会将输出文件存储在此目录下',
)
@click.option(
"--resized-shape", type=int, default=700, help='分析时把图片resize到此大小再进行。默认为 `700`',
"--resized-shape", type=int, default=608, help='分析时把图片resize到此大小再进行。默认为 `608`',
)
@click.option(
'--conf-thresh', type=float, default=0.25, help='Confidence Threshold。默认值为 `0.25`'
Expand All @@ -374,7 +380,9 @@ def layout_analyze(
model_name,
model_type,
model_backend,
model_categories,
model_fp,
model_arch_yaml,
device,
img_fp,
output_fp,
Expand All @@ -383,11 +391,18 @@ def layout_analyze(
iou_thresh,
):
"""对给定图片进行 MFD 或者 版面分析。"""
if not os.path.exists(img_fp):
raise FileNotFoundError(img_fp)

if model_categories is not None:
model_categories = model_categories.split(',')
analyzer = LayoutAnalyzer(
model_name=model_name,
model_type=model_type,
model_backend=model_backend,
model_categories=model_categories,
model_fp=model_fp,
model_arch_yaml=model_arch_yaml,
device=device,
)

Expand All @@ -400,11 +415,11 @@ def layout_analyze(
elif os.path.isdir(img_fp):
fn_list = glob.glob1(img_fp, '*g')
input_fp_list = [os.path.join(img_fp, fn) for fn in fn_list]
assert output_fp is not None, 'output_fp should NOT be None when img_fp is a directory'
assert (
output_fp is not None
), 'output_fp should NOT be None when img_fp is a directory'
os.makedirs(output_fp, exist_ok=True)
output_fp_list = [
os.path.join(output_fp, 'analysis-' + fn) for fn in fn_list
]
output_fp_list = [os.path.join(output_fp, 'analysis-' + fn) for fn in fn_list]

for input_fp, output_fp in zip(input_fp_list, output_fp_list):
out = analyzer.analyze(
Expand Down
7 changes: 7 additions & 0 deletions cnstd/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# specific language governing permissions and limitations
# under the License.

import os
import logging
from pathlib import Path
from typing import Tuple, Set, Dict, Any, Optional, Union
Expand All @@ -43,6 +44,8 @@
# 如: __version__ = '1.0.*',对应的 MODEL_VERSION 都是 '1.0'
MODEL_VERSION = '.'.join(__version__.split('.', maxsplit=2)[:2])
VOCAB_FP = Path(__file__).parent.parent / 'label_cn.txt'
# Which OSS source will be used for downloading model files, 'CN' or 'HF'
DOWNLOAD_SOURCE = os.environ.get('CNSTD_DOWNLOAD_SOURCE', 'CN')

MODEL_CONFIGS: Dict[str, Dict[str, Any]] = {
'db_resnet50': {
Expand Down Expand Up @@ -113,13 +116,17 @@

HF_HUB_REPO_ID = "breezedeus/cnstd-cnocr-models"
HF_HUB_SUBFOLDER = "models/cnstd/%s" % MODEL_VERSION
CN_OSS_ENDPOINT = (
"https://sg-models.oss-cn-beijing.aliyuncs.com/cnstd/%s/" % MODEL_VERSION
)


def format_hf_hub_url(url: str) -> dict:
return {
'repo_id': HF_HUB_REPO_ID,
'subfolder': HF_HUB_SUBFOLDER,
'filename': url,
'cn_oss': CN_OSS_ENDPOINT,
}


Expand Down
4 changes: 2 additions & 2 deletions cnstd/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from PIL import Image
import numpy as np

from .consts import MODEL_VERSION, AVAILABLE_MODELS
from .consts import MODEL_VERSION, AVAILABLE_MODELS, DOWNLOAD_SOURCE
from .model import gen_model
from .model.core import DetectionPredictor
from .utils import (
Expand Down Expand Up @@ -144,7 +144,7 @@ def _assert_and_prepare_model_files(self, model_fp, root):
% ((self._model_name, self._model_backend),)
)
url = AVAILABLE_MODELS.get_url(self._model_name, self._model_backend)
get_model_file(url, self._model_dir) # download the .zip file and unzip
get_model_file(url, self._model_dir, download_source=DOWNLOAD_SOURCE) # download the .zip file and unzip
fps = glob('%s/%s*.ckpt' % (self._model_dir, self._model_file_prefix))

self._model_fp = fps[0]
Expand Down
4 changes: 2 additions & 2 deletions cnstd/ppocr/angle_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import cv2
import numpy as np

from ..consts import MODEL_VERSION, ANGLE_CLF_MODELS, ANGLE_CLF_SPACE
from ..consts import MODEL_VERSION, ANGLE_CLF_MODELS, ANGLE_CLF_SPACE, DOWNLOAD_SOURCE
from ..utils import data_dir, get_model_file
from .postprocess import build_post_process
from .utility import (
Expand Down Expand Up @@ -89,7 +89,7 @@ def _assert_and_prepare_model_files(self, model_fp, root):
)
url = ANGLE_CLF_MODELS[(self._model_name, self._model_backend)]['url']

get_model_file(url, self._model_dir) # download the .zip file and unzip
get_model_file(url, self._model_dir, download_source=DOWNLOAD_SOURCE) # download the .zip file and unzip

self._model_fp = model_fp
logger.info('use model: %s' % self._model_fp)
Expand Down
4 changes: 2 additions & 2 deletions cnstd/ppocr/pp_detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
import numpy as np

from .consts import PP_SPACE
from ..consts import MODEL_VERSION, AVAILABLE_MODELS
from ..consts import MODEL_VERSION, AVAILABLE_MODELS, DOWNLOAD_SOURCE
from ..utils import data_dir, get_model_file, sort_boxes, get_resized_shape
from .utility import (
get_image_file_list,
Expand Down Expand Up @@ -129,7 +129,7 @@ def _assert_and_prepare_model_files(self, model_fp, root):
)
url = AVAILABLE_MODELS.get_url(self._model_name, self._model_backend)

get_model_file(url, self._model_dir) # download the .zip file and unzip
get_model_file(url, self._model_dir, download_source=DOWNLOAD_SOURCE) # download the .zip file and unzip

self._model_fp = model_fp
logger.info('use model: %s' % self._model_fp)
Expand Down
4 changes: 3 additions & 1 deletion cnstd/ppocr/utility.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,9 @@ def create_predictor(model_fp, mode, logger):
model_file_path = model_fp
if not os.path.exists(model_file_path):
raise ValueError("not find model file path {}".format(model_file_path))
sess = ort.InferenceSession(model_file_path, providers=['AzureExecutionProvider', 'CPUExecutionProvider'])
sess = ort.InferenceSession(
model_file_path, providers=ort.get_available_providers()
)
return sess, sess.get_inputs()[0], None, None


Expand Down
Loading

0 comments on commit b064e3a

Please sign in to comment.