Skip to content

Commit

Permalink
Merge pull request #48 from breezedeus/pytorch
Browse files Browse the repository at this point in the history
V1.2: suppot ppocr models
  • Loading branch information
breezedeus authored Jul 6, 2022
2 parents 2e7cef3 + 6bae18e commit 15fa3d1
Show file tree
Hide file tree
Showing 26 changed files with 2,772 additions and 344 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ demo:
package:
python setup.py sdist bdist_wheel

VERSION = 1.1.2
VERSION = 1.2
upload:
python -m twine upload dist/cnstd-$(VERSION)* --verbose

Expand Down
85 changes: 57 additions & 28 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,13 @@

作者也维护 **知识星球** [**CnOCR/CnSTD私享群**](https://t.zsxq.com/FEYZRJQ),欢迎加入。**知识星球私享群**会陆续发布一些CnOCR/CnSTD相关的私有资料,包括**更详细的训练教程****未公开的模型**,使用过程中遇到的难题解答等。本群也会发布OCR/STD相关的最新研究资料。



**v1.0.0** 版本开始,**cnstd** 从之前基于 MXNet 实现转为基于 **PyTorch** 实现。新模型的训练合并了 **ICPR MTWI 2018****ICDAR RCTW-17****ICDAR2019-LSVT** 三个数据集,包括了 **`46447`** 个训练样本,和 **`1534`** 个测试样本。

相较于 V1.0.0, **V1.1.0** 的变化主要包括
相较于之前版本, 新版本的变化主要包括

* bugfixes:修复了训练过程中发现的诸多问题;
* 检测主类 **`CnStd`** 初始化接口略有调整,去掉了参数 `model_epoch`
* backbone 结构中加入了对 **ShuffleNet** 的支持;
* 优化了训练中的超参数取值,提升了模型检测精度;
* 提供了更多的预训练模型可供选择,最小模型降至 **7.5M** 文件大小。
* 加入了对 [**PaddleOCR**](https://github.com/PaddlePaddle/PaddleOCR) 检测模型的支持;
* 部分调整了检测结果中 `box` 的表达方式,统一为 `4` 个点的坐标值;
* 修复了已知bugs。

如需要识别文本框中的文字,可以结合 **OCR** 工具包 **[cnocr](https://github.com/breezedeus/cnocr)** 一起使用。

Expand Down Expand Up @@ -45,7 +41,13 @@ pip install cnstd -i https://pypi.doubanio.com/simple

## 已有模型

当前版本(**V1.1.0**)的文字检测模型使用的是 **[DBNet](https://github.com/MhLiao/DB)**,相较于 V0.1 使用的 [PSENet](https://github.com/whai362/PSENet) 模型, DBNet 的检测耗时几乎下降了一个量级,同时检测精度也得到了极大的提升。
cnstd 从 **V1.2** 开始,可直接使用的模型包含两类:1)cnstd 自己训练的模型,通常会包含 PyTorch 和 ONNX 版本;2)从其他ocr引擎搬运过来的训练好的外部模型,ONNX化后用于 cnstd 中。

直接使用的模型都放在 [**cnstd-cnocr-models**](https://huggingface.co/breezedeus/cnstd-cnocr-models) 项目中,可免费下载使用。

### 1. cnstd 自己训练的模型

当前版本(Since **V1.1.0**)的文字检测模型使用的是 [**DBNet**](https://github.com/MhLiao/DB),相较于 V0.1 使用的 [PSENet](https://github.com/whai362/PSENet) 模型, DBNet 的检测耗时几乎下降了一个量级,同时检测精度也得到了极大的提升。

目前包含以下已训练好的模型:

Expand All @@ -63,11 +65,23 @@ pip install cnstd -i https://pypi.doubanio.com/simple
相对于两个基于 **ResNet** 的模型,基于 **MobileNet****ShuffleNet** 的模型体积更小,速度更快,建议在轻量级场景使用。

### 2. 外部模型

以下模型是 [**PaddleOCR**](https://github.com/PaddlePaddle/PaddleOCR) 中模型的 **ONNX** 版本,所以不会依赖 **PaddlePaddle** 相关工具包,故而也不支持基于这些模型在自己的领域数据上继续精调模型。这些模型支持检测**竖排文字**

| `model_name` | PyTorch 版本 | ONNX 版本 | 支持检测的语言 | 模型文件大小 |
| --------------- | ---------- | ------- | ---------- | ------ |
| ch_PP-OCRv3_det | X || 简体中问、英文、数字 | 2.3 M |
| ch_PP-OCRv2_det | X || 简体中问、英文、数字 | 2.2 M |
| en_PP-OCRv3_det | X || **英文**、数字 | 2.3 M |

更多模型可参考 [PaddleOCR/models_list.md](https://github.com/PaddlePaddle/PaddleOCR/blob/release%2F2.5/doc/doc_ch/models_list.md) 。如有其他外语(如日、韩等)检测需求,可在 **知识星球** [**CnOCR/CnSTD私享群**](https://t.zsxq.com/FEYZRJQ) 中向作者提出建议。

## 使用方法

首次使用 **cnstd** 时,系统会自动从 [贝叶智能](https://www.behye.com) 的CDN上下载zip格式的模型压缩文件,并存放于 `~/.cnstd`目录(Windows下默认路径为 `C:\Users\<username>\AppData\Roaming\cnstd`)。下载速度超快。下载后的zip文件代码会自动对其解压,然后把解压后的模型相关目录放于`~/.cnstd/1.1`目录中。
首次使用 **cnstd** 时,系统会自动下载zip格式的模型压缩文件,并存放于 `~/.cnstd`目录(Windows下默认路径为 `C:\Users\<username>\AppData\Roaming\cnstd`)。下载速度超快。下载后的zip文件代码会自动对其解压,然后把解压后的模型相关目录放于`~/.cnstd/1.2`目录中。

如果系统无法自动成功下载zip文件,则需要手动从 [百度云盘](https://pan.baidu.com/s/11_83ydAwJ1u8RnyyZtBKjw)(提取码为 `56ji`)下载对应的zip文件并把它存放于 `~/.cnstd/1.1`(Windows下为 `C:\Users\<username>\AppData\Roaming\cnstd\1.1`)目录中。模型也可从 **[cnstd-cnocr-models](https://github.com/breezedeus/cnstd-cnocr-models)** 中下载。放置好zip文件后,后面的事代码就会自动执行了。
如果系统无法自动成功下载zip文件,则需要手动从 [百度云盘](https://pan.baidu.com/s/1zDMzArCDrrXHWL0AWxwYQQ?pwd=nstd)(提取码为 `nstd`)下载对应的zip文件并把它存放于 `~/.cnstd/1.2`(Windows下为 `C:\Users\<username>\AppData\Roaming\cnstd\1.2`)目录中。模型也可从 **[cnstd-cnocr-models](https://huggingface.co/breezedeus/cnstd-cnocr-models)** 中下载。放置好zip文件后,后面的事代码就会自动执行了。

### 图片预测

Expand All @@ -81,20 +95,23 @@ class CnStd(object):

def __init__(
self,
model_name: str = 'db_shufflenet_v2_small',
model_name: str = 'ch_PP-OCRv3_det',
*,
auto_rotate_whole_image: bool = False,
rotated_bbox: bool = True,
context: str = 'cpu',
model_fp: Optional[str] = None,
model_backend: str = 'onnx', # ['pytorch', 'onnx']
root: Union[str, Path] = data_dir(),
use_angle_clf: bool = False,
angle_clf_configs: Optional[dict] = None,
**kwargs,
):
```

其中的几个参数含义如下:

* `model_name`: 模型名称,即上面表格第一列中的值。默认为 **db_shufflenet_v2_small**
* `model_name`: 模型名称,即前面模型表格第一列中的值。默认为 **ch_PP-OCRv3_det**

* `auto_rotate_whole_image`: 是否自动对整张图片进行旋转调整。默认为`False`

Expand All @@ -104,11 +121,20 @@ class CnStd(object):

* `model_fp`: 如果不使用系统自带的模型,可以通过此参数直接指定所使用的模型文件(`.ckpt`文件)。

* `model_backend` (str): 'pytorch', or 'onnx'。表明预测时是使用 PyTorch 版本模型,还是使用 ONNX 版本模型。 同样的模型,ONNX 版本的预测速度一般是 PyTorch 版本的2倍左右。默认为 `onnx`

* `root`: 模型文件所在的根目录。

* Linux/Mac下默认值为 `~/.cnstd`,表示模型文件所处文件夹类似 `~/.cnstd/1.1/db_shufflenet_v2_small`
* Linux/Mac下默认值为 `~/.cnstd`,表示模型文件所处文件夹类似 `~/.cnstd/1.2/db_shufflenet_v2_small`
* Windows下默认值为 `C:\Users\<username>\AppData\Roaming\cnstd`

* `use_angle_clf` (bool): 对于检测出的文本框,是否使用角度分类模型进行调整(检测出的文本框可能会存在倒转180度的情况)。默认为 `False`

* `angle_clf_configs` (dict): 角度分类模型对应的参数取值,主要包含以下值:

- `model_name`: 模型名称。默认为 'ch_ppocr_mobile_v2.0_cls'
- `model_fp`: 如果不使用系统自带的模型,可以通过此参数直接指定所使用的模型文件('.onnx' 文件)。默认为 `None`。具体可参考类 `AngleClassifier` 的说明

每个参数都有默认取值,所以可以不传入任何参数值进行初始化:`std = CnStd()`

文本检测使用类`CnOcr`的函数 **`detect()`**,以下是详细说明:
Expand All @@ -125,7 +151,7 @@ class CnStd(object):
np.ndarray,
List[Union[str, Path, Image.Image, np.ndarray]],
],
resized_shape: Tuple[int, int] = (768, 768),
resized_shape: Union[int, Tuple[int, int]] = (768, 768),
preserve_aspect_ratio: bool = True,
min_box_size: int = 8,
box_score_thresh: float = 0.3,
Expand All @@ -140,9 +166,11 @@ class CnStd(object):

- `img_list`: 支持对单个图片或者多个图片(列表)的检测。每个值可以是图片路径,或者已经读取进来 `PIL.Image.Image``np.ndarray`, 格式应该是 `RGB` 3 通道,shape: `(height, width, 3)`, 取值范围:`[0, 255]`

- `resized_shape`: `(height, width)`, 检测前,会先把原始图片 resize 到此大小。默认为 `(768, 768)`
- `resized_shape`: `int` or `tuple`, `tuple` 含义为 `(height, width)`, `int` 则表示高宽都为此值;
检测前,先把原始图片resize到接近此大小(只是接近,未必相等)。默认为 `(768, 768)`

注:其中取值必须都能整除 `32`。这个取值对检测结果的影响较大,可以针对自己的应用多尝试几组值,再选出最优值。例如 `(512, 768)`, `(768, 768)`, `(768, 1024)`等。
> **Note** **(注意)**
> 这个取值对检测结果的影响较大,可以针对自己的应用多尝试几组值,再选出最优值。例如 `(512, 768)`, `(768, 768)`, `(768, 1024)`等。
- `preserve_aspect_ratio`: 对原始图片 resize 时是否保持高宽比不变。默认为 `True`

Expand All @@ -160,14 +188,11 @@ class CnStd(object):

- `detected_texts`: `list`, 每个元素存储了检测出的一个框的信息,使用词典记录,包括以下几个值:

- `box`:检测出的文字对应的矩形框;4个 (`rotated_bbox==False`) 或者 5个 (`rotated_bbox==True`) 元素;

- 4个元素时的含义:对应 `rotated_bbox==False`,取值为:`[xmin, ymin, xmax, ymax]` ;
- 5个元素时的含义:对应 `rotated_bbox==True`,取值为:`[x, y, w, h, angle]`
- `box`:检测出的文字对应的矩形框;`np.ndarray`, shape: `(4, 2)`,对应 box 4个点的坐标值 `(x, y)`;

- "score":得分;`float` 类型;分数越高表示越可靠;
- `score`:得分;`float` 类型;分数越高表示越可靠;

- "croppped_img":对应 "box" 中的图片patch(`RGB`格式),会把倾斜的图片旋转为水平。`np.ndarray`类型,`shape: (height, width, 3)`, 取值范围:`[0, 255]`
- `croppped_img`:对应 "box" 中的图片patch(`RGB`格式),会把倾斜的图片旋转为水平。`np.ndarray`类型,`shape: (height, width, 3)`, 取值范围:`[0, 255]`

- 示例:

Expand Down Expand Up @@ -248,10 +273,11 @@ Usage: cnstd predict [OPTIONS]
预测单个文件,或者指定目录下的所有图片

Options:
-m, --model-name [db_resnet50|db_resnet34|db_resnet18|db_mobilenet_v3|db_mobilenet_v3_small|db_shufflenet_v2|db_shufflenet_v2_small|db_shufflenet_v2_tiny]
模型名称。默认值为 `db_shufflenet_v2_small`
--model-epoch INTEGER model epoch。默认为 `None`,表示使用系统自带的预训练模型
-p, --pretrained-model-fp TEXT 导入的训练好的模型,作为初始模型。默认为 `None`,表示使用系统自带的预训练模型
-m, --model-name [ch_PP-OCRv2_det|ch_PP-OCRv3_det|db_mobilenet_v3|db_mobilenet_v3_small|db_resnet18|db_resnet34|db_shufflenet_v2|db_shufflenet_v2_small|db_shufflenet_v2_tiny|en_PP-OCRv3_det]
模型名称。默认值为 db_shufflenet_v2_small
-b, --model-backend [pytorch|onnx]
模型类型。默认值为 `onnx`
-p, --pretrained-model-fp TEXT 使用训练好的模型。默认为 `None`,表示使用系统自带的预训练模型
-r, --rotated-bbox 是否检测带角度(非水平和垂直)的文本框。默认为 `True`
--resized-shape TEXT 格式:"height,width";
预测时把图片resize到此大小再进行预测。两个值都需要是32的倍数。默认为
Expand All @@ -260,7 +286,9 @@ Options:
--box-score-thresh FLOAT 检测结果只保留分数大于此值的文本框。默认值为 `0.3`
--preserve-aspect-ratio BOOLEAN
resize时是否保留图片原始比例。默认值为 `True`
--context TEXT 使用cpu还是 `gpu` 运行代码,也可指定为特定gpu,如`cuda:0`。默认为 `cpu`
--context TEXT 使用cpu还是 `gpu` 运行代码,也可指定为特定gpu,如`cuda:0`。默认为
`cpu`

-i, --img-file-or-dir TEXT 输入图片的文件路径或者指定的文件夹
-o, --output-dir TEXT 检测结果存放的文件夹。默认为 `./predictions`
-h, --help Show this message and exit.
Expand Down Expand Up @@ -322,4 +350,5 @@ Options:
* [x] 进一步精简模型结构,降低模型大小。
* [x] PSENet速度上还是比较慢,尝试更快的STD算法。
* [x] 加入更多的训练数据。
* [x] 加入对外部模型的支持。
* [ ] 加入对文档结构与表格的检测
7 changes: 7 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
# Release Notes


# Update 2022.07.07:发布 cnstd V1.2

主要变更:
* 加入了对 [**PaddleOCR**](https://github.com/PaddlePaddle/PaddleOCR) 检测模型的支持;
* 部分调整了检测结果中 `box` 的表达方式,统一为 `4` 个点的坐标值;
* 修复了已知bugs。


# Update 2022.05.27:发布 cnstd V1.1.2

Expand Down
2 changes: 1 addition & 1 deletion cnstd/__version__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@
# specific language governing permissions and limitations
# under the License.

__version__ = '1.1.2'
__version__ = '1.2'
52 changes: 29 additions & 23 deletions cnstd/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,22 +31,28 @@
from cnocr import CnOcr
from cnocr.consts import AVAILABLE_MODELS

cnocr_avalable = True
except ModuleNotFoundError:
cnocr_avalable = False
cnocr_available = True
except Exception:
cnocr_available = False


@st.cache(allow_output_mutation=True)
def get_ocr_model(ocr_model_name):
if not cnocr_avalable:
if not cnocr_available:
return None
model_name, model_backend = ocr_model_name
return CnOcr(model_name, model_backend=model_backend)


@st.cache(allow_output_mutation=True)
def get_std_model(std_model_name, rotated_bbox):
return CnStd(std_model_name, rotated_bbox=rotated_bbox,)
def get_std_model(std_model_name, rotated_bbox, use_angle_clf):
model_name, model_backend = std_model_name
return CnStd(
model_name,
model_backend=model_backend,
rotated_bbox=rotated_bbox,
use_angle_clf=use_angle_clf,
)


def visualize_std(img, std_out, box_score_thresh):
Expand Down Expand Up @@ -79,26 +85,25 @@ def visualize_ocr(ocr, std_out):

def main():
st.sidebar.header('CnStd 设置')
models = list(STD_MODELS.keys())
models = list(STD_MODELS.all_models())
models.sort()
std_model_name = st.sidebar.selectbox(
'模型名称', models, index=models.index('db_shufflenet_v2_small')
'模型名称', models, index=models.index(('ch_PP-OCRv3_det', 'onnx'))
)
rotated_bbox = st.sidebar.checkbox('是否检测带角度文本框', value=True)
st.sidebar.subheader('resize 后图片大小')
height = st.sidebar.select_slider(
'height', options=[384, 512, 768, 896, 1024], value=768
)
width = st.sidebar.select_slider(
'width', options=[384, 512, 768, 896, 1024], value=768
)
preserve_aspect_ratio = st.sidebar.checkbox('resize 时是否等比例缩放', value=True)
st.sidebar.subheader('检测分数阈值')
use_angle_clf = st.sidebar.checkbox('是否使用角度预测模型校正文本框', value=False)
st.sidebar.subheader('resize 后图片(长边)大小')
new_size = st.sidebar.slider('高宽尺寸', min_value=124, max_value=4096, value=768)
st.sidebar.subheader('检测参数')
box_score_thresh = st.sidebar.slider(
'(低于阈值的结果会被过滤掉)', min_value=0.05, max_value=0.95, value=0.3
'得分阈值(低于阈值的结果会被过滤掉)', min_value=0.05, max_value=0.95, value=0.3
)
min_box_size = st.sidebar.slider(
'框大小阈值(更小的文本框会被过滤掉)', min_value=4, max_value=50, value=10
)
std = get_std_model(std_model_name, rotated_bbox)
std = get_std_model(std_model_name, rotated_bbox, use_angle_clf)

if cnocr_avalable:
if cnocr_available:
st.sidebar.markdown("""---""")
st.sidebar.header('CnOcr 设置')
all_models = list(AVAILABLE_MODELS.all_models())
Expand All @@ -121,13 +126,14 @@ def main():

std_out = std.detect(
img,
resized_shape=(height, width),
preserve_aspect_ratio=preserve_aspect_ratio,
resized_shape=new_size,
preserve_aspect_ratio=True,
box_score_thresh=box_score_thresh,
min_box_size=min_box_size,
)
visualize_std(img, std_out, box_score_thresh)

if cnocr_avalable:
if cnocr_available:
visualize_ocr(ocr, std_out)
except Exception as e:
st.error(e)
Expand Down
20 changes: 17 additions & 3 deletions cnstd/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,12 @@
import time
import glob

from pprint import pformat
import numpy as np
import torchvision.transforms as T

from .utils import rotate_page
from .consts import MODEL_VERSION, MODEL_CONFIGS
from .consts import MODEL_VERSION, MODEL_CONFIGS, AVAILABLE_MODELS
from .utils import (
set_logger,
data_dir,
Expand Down Expand Up @@ -168,14 +169,25 @@ def _vis_bool(img, fp):
)


MODELS, _ = zip(*AVAILABLE_MODELS.all_models())
MODELS = sorted(MODELS)


@cli.command('predict')
@click.option(
'-m',
'--model-name',
type=click.Choice(MODEL_CONFIGS.keys()),
type=click.Choice(MODELS),
default=DEFAULT_MODEL_NAME,
help='模型名称。默认值为 %s' % DEFAULT_MODEL_NAME,
)
@click.option(
'-b',
'--model-backend',
type=click.Choice(['pytorch', 'onnx']),
default='onnx',
help='模型类型。默认值为 `onnx`',
)
@click.option(
'-p',
'--pretrained-model-fp',
Expand Down Expand Up @@ -213,6 +225,7 @@ def _vis_bool(img, fp):
)
def predict(
model_name,
model_backend,
pretrained_model_fp,
rotated_bbox,
resized_shape,
Expand All @@ -225,6 +238,7 @@ def predict(
"""预测单个文件,或者指定目录下的所有图片"""
std = CnStd(
model_name,
model_backend=model_backend,
model_fp=pretrained_model_fp,
rotated_bbox=rotated_bbox,
context=context,
Expand Down Expand Up @@ -288,7 +302,7 @@ def predict(
'%d cropped text boxes are recognized by cnocr, total time cost: %f, mean time cost: %f'
% (len(cropped_img_list), time_cost, time_cost / len(cropped_img_list))
)
logger.info('ocr result: %s' % str(ocr_out))
logger.info('ocr result: \n%s' % pformat(ocr_out))
except ModuleNotFoundError as e:
logger.warning(e)

Expand Down
Loading

0 comments on commit 15fa3d1

Please sign in to comment.