From 6c26e887d4300c3afb57a0900c6ad53798d9dfda Mon Sep 17 00:00:00 2001 From: AllentDan <41138331+AllentDan@users.noreply.github.com> Date: Wed, 12 Apr 2023 15:43:51 +0800 Subject: [PATCH] [Doc] update SAR status (#1789) * update SAR status and fix torchscript export * add reminder for SAR --- docs/en/04-supported-codebases/mmocr.md | 3 +++ docs/zh_cn/04-supported-codebases/mmocr.md | 2 ++ .../codebase/mmocr/models/text_recognition/sar_decoder.py | 4 ++-- .../codebase/mmocr/models/text_recognition/sar_encoder.py | 2 +- tests/regression/mmocr.yml | 1 - 5 files changed, 8 insertions(+), 4 deletions(-) diff --git a/docs/en/04-supported-codebases/mmocr.md b/docs/en/04-supported-codebases/mmocr.md index e964ab740d..fffe8eca94 100644 --- a/docs/en/04-supported-codebases/mmocr.md +++ b/docs/en/04-supported-codebases/mmocr.md @@ -251,6 +251,9 @@ Besides python API, mmdeploy SDK also provides other FFI (Foreign Function Inter - ABINet for TensorRT require pytorch1.10+ and TensorRT 8.4+. +- SAR uses `valid_ratio` inside network inference, which causes performance drops. When the `valid_ratio`s between + testing image and the image for conversion are quite different, the gap would be enlarged. + - For TensorRT backend, users have to choose the right config. For example, CRNN only accepts 1 channel input. Here is a recommendation table: | Model | Config | diff --git a/docs/zh_cn/04-supported-codebases/mmocr.md b/docs/zh_cn/04-supported-codebases/mmocr.md index 8dd1798ee8..efdeea6416 100644 --- a/docs/zh_cn/04-supported-codebases/mmocr.md +++ b/docs/zh_cn/04-supported-codebases/mmocr.md @@ -255,6 +255,8 @@ print(texts) - ABINet 在 TensorRT 后端要求使用 pytorch1.10+, TensorRT 8.4+。 +- SAR 在网络推广中使用 `valid_ratio`,这会让导出的 ONNX 文件精度下降。当测试图片的 `valid_ratio`s 和转换图片的值差异很大,这种下降就会越多。 + - 对于 TensorRT 后端,用户需要使用正确的配置文件。比如 CRNN 只接受单通道输入。下面是一个示例表格: | Model | Config | diff --git a/mmdeploy/codebase/mmocr/models/text_recognition/sar_decoder.py b/mmdeploy/codebase/mmocr/models/text_recognition/sar_decoder.py index 38b51e0929..5e40dff9af 100644 --- a/mmdeploy/codebase/mmocr/models/text_recognition/sar_decoder.py +++ b/mmdeploy/codebase/mmocr/models/text_recognition/sar_decoder.py @@ -52,7 +52,7 @@ def parallel_sar_decoder__2d_attention( attn_mask = torch.zeros(bsz, T, h, w + 1, c).to(attn_weight.device) for i, valid_ratio in enumerate(valid_ratios): # use torch.ceil to replace original math.ceil and if else in mmocr - valid_width = torch.ceil(w * valid_ratio).long() + valid_width = torch.tensor(w * valid_ratio).ceil().long() # use narrow to replace original [valid_width:] in mmocr attn_mask[i].narrow(2, valid_width, w + 1 - valid_width)[:] = 1 attn_mask = attn_mask[:, :, :, :w, :] @@ -123,7 +123,7 @@ def sequential_sar_decoder__2d_attention(self, attn_mask = torch.zeros(bsz, c, h, w + 1).to(attn_weight.device) for i, valid_ratio in enumerate(valid_ratios): # use torch.ceil to replace original math.ceil and if else in mmocr - valid_width = torch.ceil(w * valid_ratio).long() + valid_width = torch.tensor(w * valid_ratio).ceil().long() # use narrow to replace original [valid_width:] in mmocr attn_mask[i].narrow(2, valid_width, w + 1 - valid_width)[:] = 1 attn_mask = attn_mask[:, :, :, :w] diff --git a/mmdeploy/codebase/mmocr/models/text_recognition/sar_encoder.py b/mmdeploy/codebase/mmocr/models/text_recognition/sar_encoder.py index dc5a87f6f1..b159ba7b7d 100644 --- a/mmdeploy/codebase/mmocr/models/text_recognition/sar_encoder.py +++ b/mmdeploy/codebase/mmocr/models/text_recognition/sar_encoder.py @@ -54,7 +54,7 @@ def sar_encoder__forward( T = holistic_feat.size(1) for i, valid_ratio in enumerate(valid_ratios): # use torch.ceil to replace original math.ceil and if else in mmocr - valid_step = torch.ceil(T * valid_ratio).long() - 1 + valid_step = torch.tensor(T * valid_ratio).ceil().long() - 1 valid_hf.append(holistic_feat[i, valid_step, :]) valid_hf = torch.stack(valid_hf, dim=0) else: diff --git a/tests/regression/mmocr.yml b/tests/regression/mmocr.yml index 16ce8c5a18..2270f66ec1 100644 --- a/tests/regression/mmocr.yml +++ b/tests/regression/mmocr.yml @@ -307,7 +307,6 @@ models: pipelines: - *pipeline_ts_recognition_fp32 - *pipeline_ort_recognition_dynamic_fp32 - - *pipeline_trt_recognition_dynamic_fp32_H48_C3 - name: SATRN metafile: configs/textrecog/satrn/metafile.yml