From c1fd46641eff86b8e7ee2aa8d1943f7aced5b3a7 Mon Sep 17 00:00:00 2001 From: tink2123 Date: Wed, 30 Dec 2020 16:15:49 +0800 Subject: [PATCH 01/10] add srn for dygraph --- configs/rec/rec_mv3_none_bilstm_ctc.yml | 6 +- configs/rec/rec_mv3_none_none_ctc.yml | 4 +- configs/rec/rec_mv3_tps_bilstm_ctc.yml | 4 +- configs/rec/rec_r34_vd_none_bilstm_ctc.yml | 4 +- configs/rec/rec_r34_vd_none_none_ctc.yml | 4 +- configs/rec/rec_r34_vd_tps_bilstm_ctc.yml | 4 +- configs/rec/rec_r50_fpn_srn.yml | 106 ++++++ ppocr/data/__init__.py | 4 +- ppocr/data/imaug/__init__.py | 2 +- ppocr/data/imaug/label_ops.py | 48 +++ ppocr/data/imaug/rec_img_aug.py | 90 ++++- ppocr/data/lmdb_dataset.py | 4 +- ppocr/losses/__init__.py | 5 +- ppocr/losses/rec_srn_loss.py | 47 +++ ppocr/metrics/__init__.py | 1 + ppocr/metrics/rec_metric.py | 4 +- ppocr/modeling/architectures/base_model.py | 7 +- ppocr/modeling/backbones/__init__.py | 3 +- ppocr/modeling/backbones/rec_resnet_fpn.py | 307 ++++++++++++++++ ppocr/modeling/heads/__init__.py | 5 +- ppocr/modeling/heads/rec_srn_head.py | 279 ++++++++++++++ ppocr/modeling/heads/self_attention.py | 408 +++++++++++++++++++++ ppocr/postprocess/__init__.py | 5 +- ppocr/postprocess/rec_postprocess.py | 84 ++++- tools/export_model.py | 41 ++- tools/infer/predict_rec.py | 149 +++++++- tools/infer_rec.py | 25 +- tools/program.py | 14 +- 28 files changed, 1594 insertions(+), 70 deletions(-) create mode 100644 configs/rec/rec_r50_fpn_srn.yml create mode 100644 ppocr/losses/rec_srn_loss.py create mode 100644 ppocr/modeling/backbones/rec_resnet_fpn.py create mode 100644 ppocr/modeling/heads/rec_srn_head.py create mode 100644 ppocr/modeling/heads/self_attention.py diff --git a/configs/rec/rec_mv3_none_bilstm_ctc.yml b/configs/rec/rec_mv3_none_bilstm_ctc.yml index 38f1e869..00c1db88 100644 --- a/configs/rec/rec_mv3_none_bilstm_ctc.yml +++ b/configs/rec/rec_mv3_none_bilstm_ctc.yml @@ -1,5 +1,5 @@ Global: - use_gpu: true + use_gpu: True epoch_num: 72 log_smooth_window: 20 print_batch_step: 10 @@ -59,7 +59,7 @@ Metric: Train: dataset: - name: LMDBDateSet + name: LMDBDataSet data_dir: ./train_data/data_lmdb_release/training/ transforms: - DecodeImage: # load image @@ -78,7 +78,7 @@ Train: Eval: dataset: - name: LMDBDateSet + name: LMDBDataSet data_dir: ./train_data/data_lmdb_release/validation/ transforms: - DecodeImage: # load image diff --git a/configs/rec/rec_mv3_none_none_ctc.yml b/configs/rec/rec_mv3_none_none_ctc.yml index 33079ad4..6711b1d2 100644 --- a/configs/rec/rec_mv3_none_none_ctc.yml +++ b/configs/rec/rec_mv3_none_none_ctc.yml @@ -58,7 +58,7 @@ Metric: Train: dataset: - name: LMDBDateSet + name: LMDBDataSet data_dir: ./train_data/data_lmdb_release/training/ transforms: - DecodeImage: # load image @@ -77,7 +77,7 @@ Train: Eval: dataset: - name: LMDBDateSet + name: LMDBDataSet data_dir: ./train_data/data_lmdb_release/validation/ transforms: - DecodeImage: # load image diff --git a/configs/rec/rec_mv3_tps_bilstm_ctc.yml b/configs/rec/rec_mv3_tps_bilstm_ctc.yml index 08f68939..1b9fb0a0 100644 --- a/configs/rec/rec_mv3_tps_bilstm_ctc.yml +++ b/configs/rec/rec_mv3_tps_bilstm_ctc.yml @@ -63,7 +63,7 @@ Metric: Train: dataset: - name: LMDBDateSet + name: LMDBDataSet data_dir: ./train_data/data_lmdb_release/training/ transforms: - DecodeImage: # load image @@ -82,7 +82,7 @@ Train: Eval: dataset: - name: LMDBDateSet + name: LMDBDataSet data_dir: ./train_data/data_lmdb_release/validation/ transforms: - DecodeImage: # load image diff --git a/configs/rec/rec_r34_vd_none_bilstm_ctc.yml b/configs/rec/rec_r34_vd_none_bilstm_ctc.yml index 4ad2ff89..e4d301a6 100644 --- a/configs/rec/rec_r34_vd_none_bilstm_ctc.yml +++ b/configs/rec/rec_r34_vd_none_bilstm_ctc.yml @@ -58,7 +58,7 @@ Metric: Train: dataset: - name: LMDBDateSet + name: LMDBDataSet data_dir: ./train_data/data_lmdb_release/training/ transforms: - DecodeImage: # load image @@ -77,7 +77,7 @@ Train: Eval: dataset: - name: LMDBDateSet + name: LMDBDataSet data_dir: ./train_data/data_lmdb_release/validation/ transforms: - DecodeImage: # load image diff --git a/configs/rec/rec_r34_vd_none_none_ctc.yml b/configs/rec/rec_r34_vd_none_none_ctc.yml index 9c1eeb30..4a17a004 100644 --- a/configs/rec/rec_r34_vd_none_none_ctc.yml +++ b/configs/rec/rec_r34_vd_none_none_ctc.yml @@ -56,7 +56,7 @@ Metric: Train: dataset: - name: LMDBDateSet + name: LMDBDataSet data_dir: ./train_data/data_lmdb_release/training/ transforms: - DecodeImage: # load image @@ -75,7 +75,7 @@ Train: Eval: dataset: - name: LMDBDateSet + name: LMDBDataSet data_dir: ./train_data/data_lmdb_release/validation/ transforms: - DecodeImage: # load image diff --git a/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml b/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml index aeded492..62edf843 100644 --- a/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml +++ b/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml @@ -62,7 +62,7 @@ Metric: Train: dataset: - name: LMDBDateSet + name: LMDBDataSet data_dir: ./train_data/data_lmdb_release/training/ transforms: - DecodeImage: # load image @@ -81,7 +81,7 @@ Train: Eval: dataset: - name: LMDBDateSet + name: LMDBDataSet data_dir: ./train_data/data_lmdb_release/validation/ transforms: - DecodeImage: # load image diff --git a/configs/rec/rec_r50_fpn_srn.yml b/configs/rec/rec_r50_fpn_srn.yml new file mode 100644 index 00000000..78f8d551 --- /dev/null +++ b/configs/rec/rec_r50_fpn_srn.yml @@ -0,0 +1,106 @@ +Global: + use_gpu: True + epoch_num: 72 + log_smooth_window: 20 + print_batch_step: 5 + save_model_dir: ./output/rec/srn + save_epoch_step: 3 + # evaluation is run every 5000 iterations after the 4000th iteration + eval_batch_step: [0, 5000] + # if pretrained_model is saved in static mode, load_static_weights must set to True + cal_metric_during_train: True + pretrained_model: + checkpoints: + save_inference_dir: + use_visualdl: False + infer_img: doc/imgs_words/ch/word_1.jpg + # for data or label process + character_dict_path: + character_type: en + max_text_length: 25 + num_heads: 8 + infer_mode: False + use_space_char: False + + +Optimizer: + name: Adam + lr: + name: Cosine + learning_rate: 0.0001 + +Architecture: + model_type: rec + algorithm: SRN + in_channels: 1 + Transform: + Backbone: + name: ResNetFPN + Head: + name: SRNHead + max_text_length: 25 + num_heads: 8 + num_encoder_TUs: 2 + num_decoder_TUs: 4 + hidden_dims: 512 + +Loss: + name: SRNLoss + +PostProcess: + name: SRNLabelDecode + +Metric: + name: RecMetric + main_indicator: acc + +Train: + dataset: + name: LMDBDataSet + data_dir: ./train_data/srn_train_data_duiqi + #label_file_list: ["./train_data/ic15_data/1.txt"] + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - SRNLabelEncode: # Class handling label + - SRNRecResizeImg: + image_shape: [1, 64, 256] + - KeepKeys: + keep_keys: ['image', + 'label', + 'length', + 'encoder_word_pos', + 'gsrm_word_pos', + 'gsrm_slf_attn_bias1', + 'gsrm_slf_attn_bias2'] # dataloader will return list in this order + loader: + shuffle: False + batch_size_per_card: 64 + drop_last: True + num_workers: 4 + +Eval: + dataset: + name: LMDBDataSet + data_dir: ./train_data/data_lmdb_release/evaluation + transforms: + - DecodeImage: # load image + img_mode: BGR + channel_first: False + - SRNLabelEncode: # Class handling label + - SRNRecResizeImg: + image_shape: [1, 64, 256] + - KeepKeys: + keep_keys: ['image', + 'label', + 'length', + 'encoder_word_pos', + 'gsrm_word_pos', + 'gsrm_slf_attn_bias1', + 'gsrm_slf_attn_bias2'] + loader: + shuffle: False + drop_last: False + batch_size_per_card: 32 + num_workers: 4 diff --git a/ppocr/data/__init__.py b/ppocr/data/__init__.py index 7b0faf12..eb461ffa 100644 --- a/ppocr/data/__init__.py +++ b/ppocr/data/__init__.py @@ -33,7 +33,7 @@ import paddle.distributed as dist from ppocr.data.imaug import transform, create_operators from ppocr.data.simple_dataset import SimpleDataSet -from ppocr.data.lmdb_dataset import LMDBDateSet +from ppocr.data.lmdb_dataset import LMDBDataSet __all__ = ['build_dataloader', 'transform', 'create_operators'] @@ -54,7 +54,7 @@ signal.signal(signal.SIGTERM, term_mp) def build_dataloader(config, mode, device, logger): config = copy.deepcopy(config) - support_dict = ['SimpleDataSet', 'LMDBDateSet'] + support_dict = ['SimpleDataSet', 'LMDBDataSet'] module_name = config[mode]['dataset']['name'] assert module_name in support_dict, Exception( 'DataSet only support {}'.format(support_dict)) diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py index 6ea4dd8e..250ac75e 100644 --- a/ppocr/data/imaug/__init__.py +++ b/ppocr/data/imaug/__init__.py @@ -21,7 +21,7 @@ from .make_border_map import MakeBorderMap from .make_shrink_map import MakeShrinkMap from .random_crop_data import EastRandomCropData, PSERandomCrop -from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg +from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg from .randaugment import RandAugment from .operators import * from .label_ops import * diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py index af3308a5..986cec3d 100644 --- a/ppocr/data/imaug/label_ops.py +++ b/ppocr/data/imaug/label_ops.py @@ -98,6 +98,8 @@ class BaseRecLabelEncode(object): support_character_type, character_type) self.max_text_len = max_text_length + self.beg_str = "sos" + self.end_str = "eos" if character_type == "en": self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz" dict_character = list(self.character_str) @@ -213,3 +215,49 @@ class AttnLabelEncode(BaseRecLabelEncode): assert False, "Unsupport type %s in get_beg_end_flag_idx" \ % beg_or_end return idx + + +class SRNLabelEncode(BaseRecLabelEncode): + """ Convert between text-label and text-index """ + + def __init__(self, + max_text_length=25, + character_dict_path=None, + character_type='en', + use_space_char=False, + **kwargs): + super(SRNLabelEncode, + self).__init__(max_text_length, character_dict_path, + character_type, use_space_char) + + def add_special_char(self, dict_character): + dict_character = dict_character + [self.beg_str, self.end_str] + return dict_character + + def __call__(self, data): + text = data['label'] + text = self.encode(text) + char_num = len(self.character_str) + if text is None: + return None + if len(text) > self.max_text_len: + return None + data['length'] = np.array(len(text)) + text = text + [char_num] * (self.max_text_len - len(text)) + data['label'] = np.array(text) + return data + + def get_ignored_tokens(self): + beg_idx = self.get_beg_end_flag_idx("beg") + end_idx = self.get_beg_end_flag_idx("end") + return [beg_idx, end_idx] + + def get_beg_end_flag_idx(self, beg_or_end): + if beg_or_end == "beg": + idx = np.array(self.dict[self.beg_str]) + elif beg_or_end == "end": + idx = np.array(self.dict[self.end_str]) + else: + assert False, "Unsupport type %s in get_beg_end_flag_idx" \ + % beg_or_end + return idx diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py index 2ccb2d1d..28e6bd0b 100644 --- a/ppocr/data/imaug/rec_img_aug.py +++ b/ppocr/data/imaug/rec_img_aug.py @@ -12,20 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - import math import cv2 import numpy as np @@ -77,6 +63,26 @@ class RecResizeImg(object): return data +class SRNRecResizeImg(object): + def __init__(self, image_shape, num_heads, max_text_length, **kwargs): + self.image_shape = image_shape + self.num_heads = num_heads + self.max_text_length = max_text_length + + def __call__(self, data): + img = data['image'] + norm_img = resize_norm_img_srn(img, self.image_shape) + data['image'] = norm_img + [encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \ + srn_other_inputs(self.image_shape, self.num_heads, self.max_text_length) + + data['encoder_word_pos'] = encoder_word_pos + data['gsrm_word_pos'] = gsrm_word_pos + data['gsrm_slf_attn_bias1'] = gsrm_slf_attn_bias1 + data['gsrm_slf_attn_bias2'] = gsrm_slf_attn_bias2 + return data + + def resize_norm_img(img, image_shape): imgC, imgH, imgW = image_shape h = img.shape[0] @@ -103,7 +109,7 @@ def resize_norm_img(img, image_shape): def resize_norm_img_chinese(img, image_shape): imgC, imgH, imgW = image_shape # todo: change to 0 and modified image shape - max_wh_ratio = 0 + max_wh_ratio = imgW * 1.0 / imgH h, w = img.shape[0], img.shape[1] ratio = w * 1.0 / h max_wh_ratio = max(max_wh_ratio, ratio) @@ -126,6 +132,60 @@ def resize_norm_img_chinese(img, image_shape): return padding_im +def resize_norm_img_srn(img, image_shape): + imgC, imgH, imgW = image_shape + + img_black = np.zeros((imgH, imgW)) + im_hei = img.shape[0] + im_wid = img.shape[1] + + if im_wid <= im_hei * 1: + img_new = cv2.resize(img, (imgH * 1, imgH)) + elif im_wid <= im_hei * 2: + img_new = cv2.resize(img, (imgH * 2, imgH)) + elif im_wid <= im_hei * 3: + img_new = cv2.resize(img, (imgH * 3, imgH)) + else: + img_new = cv2.resize(img, (imgW, imgH)) + + img_np = np.asarray(img_new) + img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY) + img_black[:, 0:img_np.shape[1]] = img_np + img_black = img_black[:, :, np.newaxis] + + row, col, c = img_black.shape + c = 1 + + return np.reshape(img_black, (c, row, col)).astype(np.float32) + + +def srn_other_inputs(image_shape, num_heads, max_text_length): + + imgC, imgH, imgW = image_shape + feature_dim = int((imgH / 8) * (imgW / 8)) + + encoder_word_pos = np.array(range(0, feature_dim)).reshape( + (feature_dim, 1)).astype('int64') + gsrm_word_pos = np.array(range(0, max_text_length)).reshape( + (max_text_length, 1)).astype('int64') + + gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length)) + gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape( + [1, max_text_length, max_text_length]) + gsrm_slf_attn_bias1 = np.tile(gsrm_slf_attn_bias1, + [num_heads, 1, 1]) * [-1e9] + + gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape( + [1, max_text_length, max_text_length]) + gsrm_slf_attn_bias2 = np.tile(gsrm_slf_attn_bias2, + [num_heads, 1, 1]) * [-1e9] + + return [ + encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, + gsrm_slf_attn_bias2 + ] + + def flag(): """ flag diff --git a/ppocr/data/lmdb_dataset.py b/ppocr/data/lmdb_dataset.py index e7bb6dd3..515279fb 100644 --- a/ppocr/data/lmdb_dataset.py +++ b/ppocr/data/lmdb_dataset.py @@ -20,9 +20,9 @@ import cv2 from .imaug import transform, create_operators -class LMDBDateSet(Dataset): +class LMDBDataSet(Dataset): def __init__(self, config, mode, logger): - super(LMDBDateSet, self).__init__() + super(LMDBDataSet, self).__init__() global_config = config['Global'] dataset_config = config[mode]['dataset'] diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py index 4673d35c..b280eb33 100755 --- a/ppocr/losses/__init__.py +++ b/ppocr/losses/__init__.py @@ -23,11 +23,14 @@ def build_loss(config): # rec loss from .rec_ctc_loss import CTCLoss + from .rec_srn_loss import SRNLoss # cls loss from .cls_loss import ClsLoss - support_dict = ['DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss'] + support_dict = [ + 'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'SRNLoss' + ] config = copy.deepcopy(config) module_name = config.pop('name') diff --git a/ppocr/losses/rec_srn_loss.py b/ppocr/losses/rec_srn_loss.py new file mode 100644 index 00000000..d722ee0f --- /dev/null +++ b/ppocr/losses/rec_srn_loss.py @@ -0,0 +1,47 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import paddle +from paddle import nn + + +class SRNLoss(nn.Layer): + def __init__(self, **kwargs): + super(SRNLoss, self).__init__() + self.loss_func = paddle.nn.loss.CrossEntropyLoss(reduction="sum") + + def forward(self, predicts, batch): + predict = predicts['predict'] + word_predict = predicts['word_out'] + gsrm_predict = predicts['gsrm_out'] + label = batch[1] + + casted_label = paddle.cast(x=label, dtype='int64') + casted_label = paddle.reshape(x=casted_label, shape=[-1, 1]) + + cost_word = self.loss_func(word_predict, label=casted_label) + cost_gsrm = self.loss_func(gsrm_predict, label=casted_label) + cost_vsfd = self.loss_func(predict, label=casted_label) + + cost_word = paddle.reshape(x=paddle.sum(cost_word), shape=[1]) + cost_gsrm = paddle.reshape(x=paddle.sum(cost_gsrm), shape=[1]) + cost_vsfd = paddle.reshape(x=paddle.sum(cost_vsfd), shape=[1]) + + sum_cost = cost_word + cost_vsfd * 2.0 + cost_gsrm * 0.15 + + return {'loss': sum_cost, 'word_loss': cost_word, 'img_loss': cost_vsfd} diff --git a/ppocr/metrics/__init__.py b/ppocr/metrics/__init__.py index a0e7d912..41828f51 100644 --- a/ppocr/metrics/__init__.py +++ b/ppocr/metrics/__init__.py @@ -26,6 +26,7 @@ def build_metric(config): from .det_metric import DetMetric from .rec_metric import RecMetric from .cls_metric import ClsMetric + from .rec_metric import RecMetric support_dict = ['DetMetric', 'RecMetric', 'ClsMetric'] diff --git a/ppocr/metrics/rec_metric.py b/ppocr/metrics/rec_metric.py index bd0f92e0..459fe8e4 100644 --- a/ppocr/metrics/rec_metric.py +++ b/ppocr/metrics/rec_metric.py @@ -31,8 +31,6 @@ class RecMetric(object): if pred == target: correct_num += 1 all_num += 1 - # if all_num < 10 and kwargs.get('show_str', False): - # print('{} -> {}'.format(pred, target)) self.correct_num += correct_num self.all_num += all_num self.norm_edit_dis += norm_edit_dis @@ -48,7 +46,7 @@ class RecMetric(object): 'norm_edit_dis': 0, } """ - acc = self.correct_num / self.all_num + acc = 1.0 * self.correct_num / self.all_num norm_edit_dis = 1 - self.norm_edit_dis / self.all_num self.reset() return {'acc': acc, 'norm_edit_dis': norm_edit_dis} diff --git a/ppocr/modeling/architectures/base_model.py b/ppocr/modeling/architectures/base_model.py index ab44b53a..09b6e034 100644 --- a/ppocr/modeling/architectures/base_model.py +++ b/ppocr/modeling/architectures/base_model.py @@ -68,11 +68,14 @@ class BaseModel(nn.Layer): config["Head"]['in_channels'] = in_channels self.head = build_head(config["Head"]) - def forward(self, x): + def forward(self, x, data=None): if self.use_transform: x = self.transform(x) x = self.backbone(x) if self.use_neck: x = self.neck(x) - x = self.head(x) + if data is None: + x = self.head(x) + else: + x = self.head(x, data) return x diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py index 43103e53..03c15508 100755 --- a/ppocr/modeling/backbones/__init__.py +++ b/ppocr/modeling/backbones/__init__.py @@ -24,7 +24,8 @@ def build_backbone(config, model_type): elif model_type == 'rec' or model_type == 'cls': from .rec_mobilenet_v3 import MobileNetV3 from .rec_resnet_vd import ResNet - support_dict = ['MobileNetV3', 'ResNet', 'ResNet_FPN'] + from .rec_resnet_fpn import ResNetFPN + support_dict = ['MobileNetV3', 'ResNet', 'ResNetFPN'] else: raise NotImplementedError diff --git a/ppocr/modeling/backbones/rec_resnet_fpn.py b/ppocr/modeling/backbones/rec_resnet_fpn.py new file mode 100644 index 00000000..a7e876a2 --- /dev/null +++ b/ppocr/modeling/backbones/rec_resnet_fpn.py @@ -0,0 +1,307 @@ +#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +#Licensed under the Apache License, Version 2.0 (the "License"); +#you may not use this file except in compliance with the License. +#You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +#Unless required by applicable law or agreed to in writing, software +#distributed under the License is distributed on an "AS IS" BASIS, +#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +#See the License for the specific language governing permissions and +#limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from paddle import nn, ParamAttr +from paddle.nn import functional as F +import paddle.fluid as fluid +import paddle +import numpy as np + +__all__ = ["ResNetFPN"] + + +class ResNetFPN(nn.Layer): + def __init__(self, in_channels=1, layers=50, **kwargs): + super(ResNetFPN, self).__init__() + supported_layers = { + 18: { + 'depth': [2, 2, 2, 2], + 'block_class': BasicBlock + }, + 34: { + 'depth': [3, 4, 6, 3], + 'block_class': BasicBlock + }, + 50: { + 'depth': [3, 4, 6, 3], + 'block_class': BottleneckBlock + }, + 101: { + 'depth': [3, 4, 23, 3], + 'block_class': BottleneckBlock + }, + 152: { + 'depth': [3, 8, 36, 3], + 'block_class': BottleneckBlock + } + } + stride_list = [(2, 2), (2, 2), (1, 1), (1, 1)] + num_filters = [64, 128, 256, 512] + self.depth = supported_layers[layers]['depth'] + self.F = [] + self.conv = ConvBNLayer( + in_channels=in_channels, + out_channels=64, + kernel_size=7, + stride=2, + act="relu", + name="conv1") + self.block_list = [] + in_ch = 64 + if layers >= 50: + for block in range(len(self.depth)): + for i in range(self.depth[block]): + if layers in [101, 152] and block == 2: + if i == 0: + conv_name = "res" + str(block + 2) + "a" + else: + conv_name = "res" + str(block + 2) + "b" + str(i) + else: + conv_name = "res" + str(block + 2) + chr(97 + i) + block_list = self.add_sublayer( + "bottleneckBlock_{}_{}".format(block, i), + BottleneckBlock( + in_channels=in_ch, + out_channels=num_filters[block], + stride=stride_list[block] if i == 0 else 1, + name=conv_name)) + in_ch = num_filters[block] * 4 + self.block_list.append(block_list) + self.F.append(block_list) + else: + for block in range(len(self.depth)): + for i in range(self.depth[block]): + conv_name = "res" + str(block + 2) + chr(97 + i) + if i == 0 and block != 0: + stride = (2, 1) + else: + stride = (1, 1) + basic_block = self.add_sublayer( + conv_name, + BasicBlock( + in_channels=in_ch, + out_channels=num_filters[block], + stride=stride_list[block] if i == 0 else 1, + is_first=block == i == 0, + name=conv_name)) + in_ch = basic_block.out_channels + self.block_list.append(basic_block) + out_ch_list = [in_ch // 4, in_ch // 2, in_ch] + self.base_block = [] + self.conv_trans = [] + self.bn_block = [] + for i in [-2, -3]: + in_channels = out_ch_list[i + 1] + out_ch_list[i] + + self.base_block.append( + self.add_sublayer( + "F_{}_base_block_0".format(i), + nn.Conv2D( + in_channels=in_channels, + out_channels=out_ch_list[i], + kernel_size=1, + weight_attr=ParamAttr(trainable=True), + bias_attr=ParamAttr(trainable=True)))) + self.base_block.append( + self.add_sublayer( + "F_{}_base_block_1".format(i), + nn.Conv2D( + in_channels=out_ch_list[i], + out_channels=out_ch_list[i], + kernel_size=3, + padding=1, + weight_attr=ParamAttr(trainable=True), + bias_attr=ParamAttr(trainable=True)))) + self.base_block.append( + self.add_sublayer( + "F_{}_base_block_2".format(i), + nn.BatchNorm( + num_channels=out_ch_list[i], + act="relu", + param_attr=ParamAttr(trainable=True), + bias_attr=ParamAttr(trainable=True)))) + self.base_block.append( + self.add_sublayer( + "F_{}_base_block_3".format(i), + nn.Conv2D( + in_channels=out_ch_list[i], + out_channels=512, + kernel_size=1, + bias_attr=ParamAttr(trainable=True), + weight_attr=ParamAttr(trainable=True)))) + self.out_channels = 512 + + def __call__(self, x): + x = self.conv(x) + fpn_list = [] + F = [] + for i in range(len(self.depth)): + fpn_list.append(np.sum(self.depth[:i + 1])) + + for i, block in enumerate(self.block_list): + x = block(x) + for number in fpn_list: + if i + 1 == number: + F.append(x) + base = F[-1] + + j = 0 + for i, block in enumerate(self.base_block): + if i % 3 == 0 and i < 6: + j = j + 1 + b, c, w, h = F[-j - 1].shape + if [w, h] == list(base.shape[2:]): + base = base + else: + base = self.conv_trans[j - 1](base) + base = self.bn_block[j - 1](base) + base = paddle.concat([base, F[-j - 1]], axis=1) + base = block(base) + return base + + +class ConvBNLayer(nn.Layer): + def __init__(self, + in_channels, + out_channels, + kernel_size, + stride=1, + groups=1, + act=None, + name=None): + super(ConvBNLayer, self).__init__() + self.conv = nn.Conv2D( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=2 if stride == (1, 1) else kernel_size, + dilation=2 if stride == (1, 1) else 1, + stride=stride, + padding=(kernel_size - 1) // 2, + groups=groups, + weight_attr=ParamAttr(name=name + '.conv2d.output.1.w_0'), + bias_attr=False, ) + + if name == "conv1": + bn_name = "bn_" + name + else: + bn_name = "bn" + name[3:] + self.bn = nn.BatchNorm( + num_channels=out_channels, + act=act, + param_attr=ParamAttr(name=name + '.output.1.w_0'), + bias_attr=ParamAttr(name=name + '.output.1.b_0'), + moving_mean_name=bn_name + "_mean", + moving_variance_name=bn_name + "_variance") + + def __call__(self, x): + x = self.conv(x) + x = self.bn(x) + return x + + +class ShortCut(nn.Layer): + def __init__(self, in_channels, out_channels, stride, name, is_first=False): + super(ShortCut, self).__init__() + self.use_conv = True + + if in_channels != out_channels or stride != 1 or is_first == True: + if stride == (1, 1): + self.conv = ConvBNLayer( + in_channels, out_channels, 1, 1, name=name) + else: # stride==(2,2) + self.conv = ConvBNLayer( + in_channels, out_channels, 1, stride, name=name) + else: + self.use_conv = False + + def forward(self, x): + if self.use_conv: + x = self.conv(x) + return x + + +class BottleneckBlock(nn.Layer): + def __init__(self, in_channels, out_channels, stride, name): + super(BottleneckBlock, self).__init__() + self.conv0 = ConvBNLayer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=1, + act='relu', + name=name + "_branch2a") + self.conv1 = ConvBNLayer( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + stride=stride, + act='relu', + name=name + "_branch2b") + + self.conv2 = ConvBNLayer( + in_channels=out_channels, + out_channels=out_channels * 4, + kernel_size=1, + act=None, + name=name + "_branch2c") + + self.short = ShortCut( + in_channels=in_channels, + out_channels=out_channels * 4, + stride=stride, + is_first=False, + name=name + "_branch1") + self.out_channels = out_channels * 4 + + def forward(self, x): + y = self.conv0(x) + y = self.conv1(y) + y = self.conv2(y) + y = y + self.short(x) + y = F.relu(y) + return y + + +class BasicBlock(nn.Layer): + def __init__(self, in_channels, out_channels, stride, name, is_first): + super(BasicBlock, self).__init__() + self.conv0 = ConvBNLayer( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + act='relu', + stride=stride, + name=name + "_branch2a") + self.conv1 = ConvBNLayer( + in_channels=out_channels, + out_channels=out_channels, + kernel_size=3, + act=None, + name=name + "_branch2b") + self.short = ShortCut( + in_channels=in_channels, + out_channels=out_channels, + stride=stride, + is_first=is_first, + name=name + "_branch1") + self.out_channels = out_channels + + def forward(self, x): + y = self.conv0(x) + y = self.conv1(y) + y = y + self.short(x) + return F.relu(y) diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py index 78074709..1a39ca41 100755 --- a/ppocr/modeling/heads/__init__.py +++ b/ppocr/modeling/heads/__init__.py @@ -23,10 +23,13 @@ def build_head(config): # rec head from .rec_ctc_head import CTCHead + from .rec_srn_head import SRNHead # cls head from .cls_head import ClsHead - support_dict = ['DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead'] + support_dict = [ + 'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'SRNHead' + ] module_name = config.pop('name') assert module_name in support_dict, Exception('head only support {}'.format( diff --git a/ppocr/modeling/heads/rec_srn_head.py b/ppocr/modeling/heads/rec_srn_head.py new file mode 100644 index 00000000..8aaf65e1 --- /dev/null +++ b/ppocr/modeling/heads/rec_srn_head.py @@ -0,0 +1,279 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math +import paddle +from paddle import nn, ParamAttr +from paddle.nn import functional as F +import paddle.fluid as fluid +import numpy as np +from .self_attention import WrapEncoderForFeature +from .self_attention import WrapEncoder +from paddle.static import Program +from ppocr.modeling.backbones.rec_resnet_fpn import ResNetFPN +import paddle.fluid.framework as framework + +from collections import OrderedDict +gradient_clip = 10 + + +class PVAM(nn.Layer): + def __init__(self, in_channels, char_num, max_text_length, num_heads, + num_encoder_tus, hidden_dims): + super(PVAM, self).__init__() + self.char_num = char_num + self.max_length = max_text_length + self.num_heads = num_heads + self.num_encoder_TUs = num_encoder_tus + self.hidden_dims = hidden_dims + # Transformer encoder + t = 256 + c = 512 + self.wrap_encoder_for_feature = WrapEncoderForFeature( + src_vocab_size=1, + max_length=t, + n_layer=self.num_encoder_TUs, + n_head=self.num_heads, + d_key=int(self.hidden_dims / self.num_heads), + d_value=int(self.hidden_dims / self.num_heads), + d_model=self.hidden_dims, + d_inner_hid=self.hidden_dims, + prepostprocess_dropout=0.1, + attention_dropout=0.1, + relu_dropout=0.1, + preprocess_cmd="n", + postprocess_cmd="da", + weight_sharing=True) + + # PVAM + self.flatten0 = paddle.nn.Flatten(start_axis=0, stop_axis=1) + self.fc0 = paddle.nn.Linear( + in_features=in_channels, + out_features=in_channels, ) + self.emb = paddle.nn.Embedding( + num_embeddings=self.max_length, embedding_dim=in_channels) + self.flatten1 = paddle.nn.Flatten(start_axis=0, stop_axis=2) + self.fc1 = paddle.nn.Linear( + in_features=in_channels, out_features=1, bias_attr=False) + + def forward(self, inputs, encoder_word_pos, gsrm_word_pos): + b, c, h, w = inputs.shape + conv_features = paddle.reshape(inputs, shape=[-1, c, h * w]) + conv_features = paddle.transpose(conv_features, perm=[0, 2, 1]) + # transformer encoder + b, t, c = conv_features.shape + + enc_inputs = [conv_features, encoder_word_pos, None] + word_features = self.wrap_encoder_for_feature(enc_inputs) + + # pvam + b, t, c = word_features.shape + word_features = self.fc0(word_features) + word_features_ = paddle.reshape(word_features, [-1, 1, t, c]) + word_features_ = paddle.tile(word_features_, [1, self.max_length, 1, 1]) + word_pos_feature = self.emb(gsrm_word_pos) + word_pos_feature_ = paddle.reshape(word_pos_feature, + [-1, self.max_length, 1, c]) + word_pos_feature_ = paddle.tile(word_pos_feature_, [1, 1, t, 1]) + y = word_pos_feature_ + word_features_ + y = F.tanh(y) + attention_weight = self.fc1(y) + attention_weight = paddle.reshape( + attention_weight, shape=[-1, self.max_length, t]) + attention_weight = F.softmax(attention_weight, axis=-1) + pvam_features = paddle.matmul(attention_weight, + word_features) #[b, max_length, c] + return pvam_features + + +class GSRM(nn.Layer): + def __init__(self, in_channels, char_num, max_text_length, num_heads, + num_encoder_tus, num_decoder_tus, hidden_dims): + super(GSRM, self).__init__() + self.char_num = char_num + self.max_length = max_text_length + self.num_heads = num_heads + self.num_encoder_TUs = num_encoder_tus + self.num_decoder_TUs = num_decoder_tus + self.hidden_dims = hidden_dims + + self.fc0 = paddle.nn.Linear( + in_features=in_channels, out_features=self.char_num) + self.wrap_encoder0 = WrapEncoder( + src_vocab_size=self.char_num + 1, + max_length=self.max_length, + n_layer=self.num_decoder_TUs, + n_head=self.num_heads, + d_key=int(self.hidden_dims / self.num_heads), + d_value=int(self.hidden_dims / self.num_heads), + d_model=self.hidden_dims, + d_inner_hid=self.hidden_dims, + prepostprocess_dropout=0.1, + attention_dropout=0.1, + relu_dropout=0.1, + preprocess_cmd="n", + postprocess_cmd="da", + weight_sharing=True) + + self.wrap_encoder1 = WrapEncoder( + src_vocab_size=self.char_num + 1, + max_length=self.max_length, + n_layer=self.num_decoder_TUs, + n_head=self.num_heads, + d_key=int(self.hidden_dims / self.num_heads), + d_value=int(self.hidden_dims / self.num_heads), + d_model=self.hidden_dims, + d_inner_hid=self.hidden_dims, + prepostprocess_dropout=0.1, + attention_dropout=0.1, + relu_dropout=0.1, + preprocess_cmd="n", + postprocess_cmd="da", + weight_sharing=True) + + self.mul = lambda x: paddle.matmul(x=x, + y=self.wrap_encoder0.prepare_decoder.emb0.weight, + transpose_y=True) + + def forward(self, inputs, gsrm_word_pos, gsrm_slf_attn_bias1, + gsrm_slf_attn_bias2): + # ===== GSRM Visual-to-semantic embedding block ===== + b, t, c = inputs.shape + pvam_features = paddle.reshape(inputs, [-1, c]) + word_out = self.fc0(pvam_features) + word_ids = paddle.argmax(F.softmax(word_out), axis=1) + word_ids = paddle.reshape(x=word_ids, shape=[-1, t, 1]) + + #===== GSRM Semantic reasoning block ===== + """ + This module is achieved through bi-transformers, + ngram_feature1 is the froward one, ngram_fetaure2 is the backward one + """ + pad_idx = self.char_num + + word1 = paddle.cast(word_ids, "float32") + word1 = F.pad(word1, [1, 0], value=1.0 * pad_idx, data_format="NLC") + word1 = paddle.cast(word1, "int64") + word1 = word1[:, :-1, :] + word2 = word_ids + + enc_inputs_1 = [word1, gsrm_word_pos, gsrm_slf_attn_bias1] + enc_inputs_2 = [word2, gsrm_word_pos, gsrm_slf_attn_bias2] + + gsrm_feature1 = self.wrap_encoder0(enc_inputs_1) + gsrm_feature2 = self.wrap_encoder1(enc_inputs_2) + + gsrm_feature2 = F.pad(gsrm_feature2, [0, 1], + value=0., + data_format="NLC") + gsrm_feature2 = gsrm_feature2[:, 1:, ] + gsrm_features = gsrm_feature1 + gsrm_feature2 + + gsrm_out = self.mul(gsrm_features) + + b, t, c = gsrm_out.shape + gsrm_out = paddle.reshape(gsrm_out, [-1, c]) + + return gsrm_features, word_out, gsrm_out + + +class VSFD(nn.Layer): + def __init__(self, in_channels=512, pvam_ch=512, char_num=38): + super(VSFD, self).__init__() + self.char_num = char_num + self.fc0 = paddle.nn.Linear( + in_features=in_channels * 2, out_features=pvam_ch) + self.fc1 = paddle.nn.Linear( + in_features=pvam_ch, out_features=self.char_num) + + def forward(self, pvam_feature, gsrm_feature): + b, t, c1 = pvam_feature.shape + b, t, c2 = gsrm_feature.shape + combine_feature_ = paddle.concat([pvam_feature, gsrm_feature], axis=2) + img_comb_feature_ = paddle.reshape( + combine_feature_, shape=[-1, c1 + c2]) + img_comb_feature_map = self.fc0(img_comb_feature_) + img_comb_feature_map = F.sigmoid(img_comb_feature_map) + img_comb_feature_map = paddle.reshape( + img_comb_feature_map, shape=[-1, t, c1]) + combine_feature = img_comb_feature_map * pvam_feature + ( + 1.0 - img_comb_feature_map) * gsrm_feature + img_comb_feature = paddle.reshape(combine_feature, shape=[-1, c1]) + + out = self.fc1(img_comb_feature) + return out + + +class SRNHead(nn.Layer): + def __init__(self, in_channels, out_channels, max_text_length, num_heads, + num_encoder_TUs, num_decoder_TUs, hidden_dims, **kwargs): + super(SRNHead, self).__init__() + self.char_num = out_channels + self.max_length = max_text_length + self.num_heads = num_heads + self.num_encoder_TUs = num_encoder_TUs + self.num_decoder_TUs = num_decoder_TUs + self.hidden_dims = hidden_dims + + self.pvam = PVAM( + in_channels=in_channels, + char_num=self.char_num, + max_text_length=self.max_length, + num_heads=self.num_heads, + num_encoder_tus=self.num_encoder_TUs, + hidden_dims=self.hidden_dims) + + self.gsrm = GSRM( + in_channels=in_channels, + char_num=self.char_num, + max_text_length=self.max_length, + num_heads=self.num_heads, + num_encoder_tus=self.num_encoder_TUs, + num_decoder_tus=self.num_decoder_TUs, + hidden_dims=self.hidden_dims) + self.vsfd = VSFD(in_channels=in_channels) + + self.gsrm.wrap_encoder1.prepare_decoder.emb0 = self.gsrm.wrap_encoder0.prepare_decoder.emb0 + + def forward(self, inputs, others): + encoder_word_pos = others[0] + gsrm_word_pos = others[1] + gsrm_slf_attn_bias1 = others[2] + gsrm_slf_attn_bias2 = others[3] + + pvam_feature = self.pvam(inputs, encoder_word_pos, gsrm_word_pos) + + gsrm_feature, word_out, gsrm_out = self.gsrm( + pvam_feature, gsrm_word_pos, gsrm_slf_attn_bias1, + gsrm_slf_attn_bias2) + + final_out = self.vsfd(pvam_feature, gsrm_feature) + if not self.training: + final_out = F.softmax(final_out, axis=1) + + _, decoded_out = paddle.topk(final_out, k=1) + + predicts = OrderedDict([ + ('predict', final_out), + ('pvam_feature', pvam_feature), + ('decoded_out', decoded_out), + ('word_out', word_out), + ('gsrm_out', gsrm_out), + ]) + + return predicts diff --git a/ppocr/modeling/heads/self_attention.py b/ppocr/modeling/heads/self_attention.py new file mode 100644 index 00000000..6aeb8f0c --- /dev/null +++ b/ppocr/modeling/heads/self_attention.py @@ -0,0 +1,408 @@ +# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import math + +import paddle +from paddle import ParamAttr, nn +from paddle import nn, ParamAttr +from paddle.nn import functional as F +import paddle.fluid as fluid +import numpy as np +gradient_clip = 10 + + +class WrapEncoderForFeature(nn.Layer): + def __init__(self, + src_vocab_size, + max_length, + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd, + postprocess_cmd, + weight_sharing, + bos_idx=0): + super(WrapEncoderForFeature, self).__init__() + + self.prepare_encoder = PrepareEncoder( + src_vocab_size, + d_model, + max_length, + prepostprocess_dropout, + bos_idx=bos_idx, + word_emb_param_name="src_word_emb_table") + self.encoder = Encoder(n_layer, n_head, d_key, d_value, d_model, + d_inner_hid, prepostprocess_dropout, + attention_dropout, relu_dropout, preprocess_cmd, + postprocess_cmd) + + def forward(self, enc_inputs): + conv_features, src_pos, src_slf_attn_bias = enc_inputs + enc_input = self.prepare_encoder(conv_features, src_pos) + enc_output = self.encoder(enc_input, src_slf_attn_bias) + return enc_output + + +class WrapEncoder(nn.Layer): + """ + embedder + encoder + """ + + def __init__(self, + src_vocab_size, + max_length, + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd, + postprocess_cmd, + weight_sharing, + bos_idx=0): + super(WrapEncoder, self).__init__() + + self.prepare_decoder = PrepareDecoder( + src_vocab_size, + d_model, + max_length, + prepostprocess_dropout, + bos_idx=bos_idx) + self.encoder = Encoder(n_layer, n_head, d_key, d_value, d_model, + d_inner_hid, prepostprocess_dropout, + attention_dropout, relu_dropout, preprocess_cmd, + postprocess_cmd) + + def forward(self, enc_inputs): + src_word, src_pos, src_slf_attn_bias = enc_inputs + enc_input = self.prepare_decoder(src_word, src_pos) + enc_output = self.encoder(enc_input, src_slf_attn_bias) + return enc_output + + +class Encoder(nn.Layer): + """ + encoder + """ + + def __init__(self, + n_layer, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd="n", + postprocess_cmd="da"): + + super(Encoder, self).__init__() + + self.encoder_layers = list() + for i in range(n_layer): + self.encoder_layers.append( + self.add_sublayer( + "layer_%d" % i, + EncoderLayer(n_head, d_key, d_value, d_model, d_inner_hid, + prepostprocess_dropout, attention_dropout, + relu_dropout, preprocess_cmd, + postprocess_cmd))) + self.processer = PrePostProcessLayer(preprocess_cmd, d_model, + prepostprocess_dropout) + + def forward(self, enc_input, attn_bias): + for encoder_layer in self.encoder_layers: + enc_output = encoder_layer(enc_input, attn_bias) + enc_input = enc_output + enc_output = self.processer(enc_output) + return enc_output + + +class EncoderLayer(nn.Layer): + """ + EncoderLayer + """ + + def __init__(self, + n_head, + d_key, + d_value, + d_model, + d_inner_hid, + prepostprocess_dropout, + attention_dropout, + relu_dropout, + preprocess_cmd="n", + postprocess_cmd="da"): + + super(EncoderLayer, self).__init__() + self.preprocesser1 = PrePostProcessLayer(preprocess_cmd, d_model, + prepostprocess_dropout) + self.self_attn = MultiHeadAttention(d_key, d_value, d_model, n_head, + attention_dropout) + self.postprocesser1 = PrePostProcessLayer(postprocess_cmd, d_model, + prepostprocess_dropout) + + self.preprocesser2 = PrePostProcessLayer(preprocess_cmd, d_model, + prepostprocess_dropout) + self.ffn = FFN(d_inner_hid, d_model, relu_dropout) + self.postprocesser2 = PrePostProcessLayer(postprocess_cmd, d_model, + prepostprocess_dropout) + + def forward(self, enc_input, attn_bias): + attn_output = self.self_attn( + self.preprocesser1(enc_input), None, None, attn_bias) + attn_output = self.postprocesser1(attn_output, enc_input) + ffn_output = self.ffn(self.preprocesser2(attn_output)) + ffn_output = self.postprocesser2(ffn_output, attn_output) + return ffn_output + + +class MultiHeadAttention(nn.Layer): + """ + Multi-Head Attention + """ + + def __init__(self, d_key, d_value, d_model, n_head=1, dropout_rate=0.): + super(MultiHeadAttention, self).__init__() + self.n_head = n_head + self.d_key = d_key + self.d_value = d_value + self.d_model = d_model + self.dropout_rate = dropout_rate + self.q_fc = paddle.nn.Linear( + in_features=d_model, out_features=d_key * n_head, bias_attr=False) + self.k_fc = paddle.nn.Linear( + in_features=d_model, out_features=d_key * n_head, bias_attr=False) + self.v_fc = paddle.nn.Linear( + in_features=d_model, out_features=d_value * n_head, bias_attr=False) + self.proj_fc = paddle.nn.Linear( + in_features=d_value * n_head, out_features=d_model, bias_attr=False) + + def _prepare_qkv(self, queries, keys, values, cache=None): + if keys is None: # self-attention + keys, values = queries, queries + static_kv = False + else: # cross-attention + static_kv = True + + q = self.q_fc(queries) + q = paddle.reshape(x=q, shape=[0, 0, self.n_head, self.d_key]) + q = paddle.transpose(x=q, perm=[0, 2, 1, 3]) + + if cache is not None and static_kv and "static_k" in cache: + # for encoder-decoder attention in inference and has cached + k = cache["static_k"] + v = cache["static_v"] + else: + k = self.k_fc(keys) + v = self.v_fc(values) + k = paddle.reshape(x=k, shape=[0, 0, self.n_head, self.d_key]) + k = paddle.transpose(x=k, perm=[0, 2, 1, 3]) + v = paddle.reshape(x=v, shape=[0, 0, self.n_head, self.d_value]) + v = paddle.transpose(x=v, perm=[0, 2, 1, 3]) + + if cache is not None: + if static_kv and not "static_k" in cache: + # for encoder-decoder attention in inference and has not cached + cache["static_k"], cache["static_v"] = k, v + elif not static_kv: + # for decoder self-attention in inference + cache_k, cache_v = cache["k"], cache["v"] + k = paddle.concat([cache_k, k], axis=2) + v = paddle.concat([cache_v, v], axis=2) + cache["k"], cache["v"] = k, v + + return q, k, v + + def forward(self, queries, keys, values, attn_bias, cache=None): + # compute q ,k ,v + keys = queries if keys is None else keys + values = keys if values is None else values + q, k, v = self._prepare_qkv(queries, keys, values, cache) + + # scale dot product attention + product = paddle.matmul(x=q, y=k, transpose_y=True) + product = product * self.d_model**-0.5 + if attn_bias is not None: + product += attn_bias + weights = F.softmax(product) + if self.dropout_rate: + weights = F.dropout( + weights, p=self.dropout_rate, mode="downscale_in_infer") + out = paddle.matmul(weights, v) + + # combine heads + out = paddle.transpose(out, perm=[0, 2, 1, 3]) + out = paddle.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]]) + + # project to output + out = self.proj_fc(out) + + return out + + +class PrePostProcessLayer(nn.Layer): + """ + PrePostProcessLayer + """ + + def __init__(self, process_cmd, d_model, dropout_rate): + super(PrePostProcessLayer, self).__init__() + self.process_cmd = process_cmd + self.functors = [] + for cmd in self.process_cmd: + if cmd == "a": # add residual connection + self.functors.append(lambda x, y: x + y if y is not None else x) + elif cmd == "n": # add layer normalization + self.functors.append( + self.add_sublayer( + "layer_norm_%d" % len( + self.sublayers(include_sublayers=False)), + paddle.nn.LayerNorm( + normalized_shape=d_model, + weight_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(1.)), + bias_attr=fluid.ParamAttr( + initializer=fluid.initializer.Constant(0.))))) + elif cmd == "d": # add dropout + self.functors.append(lambda x: F.dropout( + x, p=dropout_rate, mode="downscale_in_infer") + if dropout_rate else x) + + def forward(self, x, residual=None): + for i, cmd in enumerate(self.process_cmd): + if cmd == "a": + x = self.functors[i](x, residual) + else: + x = self.functors[i](x) + return x + + +class PrepareEncoder(nn.Layer): + def __init__(self, + src_vocab_size, + src_emb_dim, + src_max_len, + dropout_rate=0, + bos_idx=0, + word_emb_param_name=None, + pos_enc_param_name=None): + super(PrepareEncoder, self).__init__() + self.src_emb_dim = src_emb_dim + self.src_max_len = src_max_len + self.emb = paddle.nn.Embedding( + num_embeddings=self.src_max_len, + embedding_dim=self.src_emb_dim, + sparse=True) + self.dropout_rate = dropout_rate + + def forward(self, src_word, src_pos): + src_word_emb = src_word + src_word_emb = fluid.layers.cast(src_word_emb, 'float32') + src_word_emb = paddle.scale(x=src_word_emb, scale=self.src_emb_dim**0.5) + src_pos = paddle.squeeze(src_pos, axis=-1) + src_pos_enc = self.emb(src_pos) + src_pos_enc.stop_gradient = True + enc_input = src_word_emb + src_pos_enc + if self.dropout_rate: + out = F.dropout( + x=enc_input, p=self.dropout_rate, mode="downscale_in_infer") + else: + out = enc_input + return out + + +class PrepareDecoder(nn.Layer): + def __init__(self, + src_vocab_size, + src_emb_dim, + src_max_len, + dropout_rate=0, + bos_idx=0, + word_emb_param_name=None, + pos_enc_param_name=None): + super(PrepareDecoder, self).__init__() + self.src_emb_dim = src_emb_dim + """ + self.emb0 = Embedding(num_embeddings=src_vocab_size, + embedding_dim=src_emb_dim) + """ + self.emb0 = paddle.nn.Embedding( + num_embeddings=src_vocab_size, + embedding_dim=self.src_emb_dim, + weight_attr=paddle.ParamAttr( + name=word_emb_param_name, + initializer=nn.initializer.Normal(0., src_emb_dim**-0.5))) + self.emb1 = paddle.nn.Embedding( + num_embeddings=src_max_len, + embedding_dim=self.src_emb_dim, + weight_attr=paddle.ParamAttr(name=pos_enc_param_name)) + self.dropout_rate = dropout_rate + + def forward(self, src_word, src_pos): + src_word = fluid.layers.cast(src_word, 'int64') + src_word = paddle.squeeze(src_word, axis=-1) + src_word_emb = self.emb0(src_word) + src_word_emb = paddle.scale(x=src_word_emb, scale=self.src_emb_dim**0.5) + src_pos = paddle.squeeze(src_pos, axis=-1) + src_pos_enc = self.emb1(src_pos) + src_pos_enc.stop_gradient = True + enc_input = src_word_emb + src_pos_enc + if self.dropout_rate: + out = F.dropout( + x=enc_input, p=self.dropout_rate, mode="downscale_in_infer") + else: + out = enc_input + return out + + +class FFN(nn.Layer): + """ + Feed-Forward Network + """ + + def __init__(self, d_inner_hid, d_model, dropout_rate): + super(FFN, self).__init__() + self.dropout_rate = dropout_rate + self.fc1 = paddle.nn.Linear( + in_features=d_model, out_features=d_inner_hid) + self.fc2 = paddle.nn.Linear( + in_features=d_inner_hid, out_features=d_model) + + def forward(self, x): + hidden = self.fc1(x) + hidden = F.relu(hidden) + if self.dropout_rate: + hidden = F.dropout( + hidden, p=self.dropout_rate, mode="downscale_in_infer") + out = self.fc2(hidden) + return out diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py index c9b42e08..0156e438 100644 --- a/ppocr/postprocess/__init__.py +++ b/ppocr/postprocess/__init__.py @@ -26,11 +26,12 @@ def build_post_process(config, global_config=None): from .db_postprocess import DBPostProcess from .east_postprocess import EASTPostProcess from .sast_postprocess import SASTPostProcess - from .rec_postprocess import CTCLabelDecode, AttnLabelDecode + from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode from .cls_postprocess import ClsPostProcess support_dict = [ - 'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode', 'AttnLabelDecode', 'ClsPostProcess' + 'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode', + 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode' ] config = copy.deepcopy(config) diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index a18e101b..c2303cea 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -29,6 +29,9 @@ class BaseRecLabelDecode(object): assert character_type in support_character_type, "Only {} are supported now but get {}".format( support_character_type, character_type) + self.beg_str = "sos" + self.end_str = "eos" + if character_type == "en": self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz" dict_character = list(self.character_str) @@ -104,7 +107,6 @@ class CTCLabelDecode(BaseRecLabelDecode): def __call__(self, preds, label=None, *args, **kwargs): if isinstance(preds, paddle.Tensor): preds = preds.numpy() - preds_idx = preds.argmax(axis=2) preds_prob = preds.max(axis=2) text = self.decode(preds_idx, preds_prob) @@ -153,3 +155,83 @@ class AttnLabelDecode(BaseRecLabelDecode): assert False, "unsupport type %s in get_beg_end_flag_idx" \ % beg_or_end return idx + + +class SRNLabelDecode(BaseRecLabelDecode): + """ Convert between text-label and text-index """ + + def __init__(self, + character_dict_path=None, + character_type='en', + use_space_char=False, + **kwargs): + super(SRNLabelDecode, self).__init__(character_dict_path, + character_type, use_space_char) + + def __call__(self, preds, label=None, *args, **kwargs): + pred = preds['predict'] + char_num = len(self.character_str) + 2 + if isinstance(pred, paddle.Tensor): + pred = pred.numpy() + pred = np.reshape(pred, [-1, char_num]) + + preds_idx = np.argmax(pred, axis=1) + preds_prob = np.max(pred, axis=1) + + preds_idx = np.reshape(preds_idx, [-1, 25]) + + preds_prob = np.reshape(preds_prob, [-1, 25]) + + text = self.decode(preds_idx, preds_prob) + + if label is None: + return text + label = self.decode(label, is_remove_duplicate=False) + return text, label + + def decode(self, text_index, text_prob=None, is_remove_duplicate=True): + """ convert text-index into text-label. """ + result_list = [] + ignored_tokens = self.get_ignored_tokens() + batch_size = len(text_index) + + for batch_idx in range(batch_size): + char_list = [] + conf_list = [] + for idx in range(len(text_index[batch_idx])): + if text_index[batch_idx][idx] in ignored_tokens: + continue + if is_remove_duplicate: + # only for predict + if idx > 0 and text_index[batch_idx][idx - 1] == text_index[ + batch_idx][idx]: + continue + char_list.append(self.character[int(text_index[batch_idx][ + idx])]) + if text_prob is not None: + conf_list.append(text_prob[batch_idx][idx]) + else: + conf_list.append(1) + + text = ''.join(char_list) + result_list.append((text, np.mean(conf_list))) + return result_list + + def add_special_char(self, dict_character): + dict_character = dict_character + [self.beg_str, self.end_str] + return dict_character + + def get_ignored_tokens(self): + beg_idx = self.get_beg_end_flag_idx("beg") + end_idx = self.get_beg_end_flag_idx("end") + return [beg_idx, end_idx] + + def get_beg_end_flag_idx(self, beg_or_end): + if beg_or_end == "beg": + idx = np.array(self.dict[self.beg_str]) + elif beg_or_end == "end": + idx = np.array(self.dict[self.end_str]) + else: + assert False, "unsupport type %s in get_beg_end_flag_idx" \ + % beg_or_end + return idx diff --git a/tools/export_model.py b/tools/export_model.py index 74357d58..58dc0def 100755 --- a/tools/export_model.py +++ b/tools/export_model.py @@ -31,6 +31,14 @@ from ppocr.utils.logging import get_logger from tools.program import load_config, merge_config, ArgsParser +def parse_args(): + parser = argparse.ArgumentParser() + parser.add_argument("-c", "--config", help="configuration file to use") + parser.add_argument( + "-o", "--output_path", type=str, default='./output/infer/') + return parser.parse_args() + + def main(): FLAGS = ArgsParser().parse_args() config = load_config(FLAGS.config) @@ -51,14 +59,33 @@ def main(): model.eval() save_path = '{}/inference'.format(config['Global']['save_inference_dir']) - infer_shape = [3, 32, 100] if config['Architecture'][ - 'model_type'] != "det" else [3, 640, 640] - model = to_static( - model, - input_spec=[ + + if config['Architecture']['algorithm'] == "SRN": + other_shape = [ paddle.static.InputSpec( - shape=[None] + infer_shape, dtype='float32') - ]) + shape=[None, 1, 64, 256], dtype='float32'), [ + paddle.static.InputSpec( + shape=[None, 256, 1], + dtype="int64"), paddle.static.InputSpec( + shape=[None, 25, 1], + dtype="int64"), paddle.static.InputSpec( + shape=[None, 8, 25, 25], dtype="int64"), + paddle.static.InputSpec( + shape=[None, 8, 25, 25], dtype="int64") + ] + ] + model = to_static(model, input_spec=other_shape) + + else: + infer_shape = [3, 32, 100] if config['Architecture'][ + 'model_type'] != "det" else [3, 640, 640] + model = to_static( + model, + input_spec=[ + paddle.static.InputSpec( + shape=[None] + infer_shape, dtype='float32') + ]) + paddle.jit.save(model, save_path) logger.info('inference model is saved to {}'.format(save_path)) diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py index 974fdbb6..fd895e50 100755 --- a/tools/infer/predict_rec.py +++ b/tools/infer/predict_rec.py @@ -25,6 +25,7 @@ import numpy as np import math import time import traceback +import paddle import tools.infer.utility as utility from ppocr.postprocess import build_post_process @@ -46,6 +47,13 @@ class TextRecognizer(object): "character_dict_path": args.rec_char_dict_path, "use_space_char": args.use_space_char } + if self.rec_algorithm == "SRN": + postprocess_params = { + 'name': 'SRNLabelDecode', + "character_type": args.rec_char_type, + "character_dict_path": args.rec_char_dict_path, + "use_space_char": args.use_space_char + } self.postprocess_op = build_post_process(postprocess_params) self.predictor, self.input_tensor, self.output_tensors = \ utility.create_predictor(args, 'rec', logger) @@ -70,6 +78,78 @@ class TextRecognizer(object): padding_im[:, :, 0:resized_w] = resized_image return padding_im + def resize_norm_img_srn(self, img, image_shape): + imgC, imgH, imgW = image_shape + + img_black = np.zeros((imgH, imgW)) + im_hei = img.shape[0] + im_wid = img.shape[1] + + if im_wid <= im_hei * 1: + img_new = cv2.resize(img, (imgH * 1, imgH)) + elif im_wid <= im_hei * 2: + img_new = cv2.resize(img, (imgH * 2, imgH)) + elif im_wid <= im_hei * 3: + img_new = cv2.resize(img, (imgH * 3, imgH)) + else: + img_new = cv2.resize(img, (imgW, imgH)) + + img_np = np.asarray(img_new) + img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY) + img_black[:, 0:img_np.shape[1]] = img_np + img_black = img_black[:, :, np.newaxis] + + row, col, c = img_black.shape + c = 1 + + return np.reshape(img_black, (c, row, col)).astype(np.float32) + + def srn_other_inputs(self, image_shape, num_heads, max_text_length): + + imgC, imgH, imgW = image_shape + feature_dim = int((imgH / 8) * (imgW / 8)) + + encoder_word_pos = np.array(range(0, feature_dim)).reshape( + (feature_dim, 1)).astype('int64') + gsrm_word_pos = np.array(range(0, max_text_length)).reshape( + (max_text_length, 1)).astype('int64') + + gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length)) + gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape( + [-1, 1, max_text_length, max_text_length]) + gsrm_slf_attn_bias1 = np.tile( + gsrm_slf_attn_bias1, + [1, num_heads, 1, 1]).astype('float32') * [-1e9] + + gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape( + [-1, 1, max_text_length, max_text_length]) + gsrm_slf_attn_bias2 = np.tile( + gsrm_slf_attn_bias2, + [1, num_heads, 1, 1]).astype('float32') * [-1e9] + + encoder_word_pos = encoder_word_pos[np.newaxis, :] + gsrm_word_pos = gsrm_word_pos[np.newaxis, :] + + return [ + encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, + gsrm_slf_attn_bias2 + ] + + def process_image_srn(self, img, image_shape, num_heads, max_text_length): + norm_img = self.resize_norm_img_srn(img, image_shape) + norm_img = norm_img[np.newaxis, :] + + [encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \ + self.srn_other_inputs(image_shape, num_heads, max_text_length) + + gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32) + gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32) + encoder_word_pos = encoder_word_pos.astype(np.int64) + gsrm_word_pos = gsrm_word_pos.astype(np.int64) + + return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, + gsrm_slf_attn_bias2) + def __call__(self, img_list): img_num = len(img_list) # Calculate the aspect ratio of all text bars @@ -93,21 +173,64 @@ class TextRecognizer(object): wh_ratio = w * 1.0 / h max_wh_ratio = max(max_wh_ratio, wh_ratio) for ino in range(beg_img_no, end_img_no): - # norm_img = self.resize_norm_img(img_list[ino], max_wh_ratio) - norm_img = self.resize_norm_img(img_list[indices[ino]], - max_wh_ratio) - norm_img = norm_img[np.newaxis, :] - norm_img_batch.append(norm_img) + if self.rec_algorithm != "SRN": + norm_img = self.resize_norm_img(img_list[indices[ino]], + max_wh_ratio) + norm_img = norm_img[np.newaxis, :] + norm_img_batch.append(norm_img) + else: + norm_img = self.process_image_srn( + img_list[indices[ino]], self.rec_image_shape, 8, 25) + encoder_word_pos_list = [] + gsrm_word_pos_list = [] + gsrm_slf_attn_bias1_list = [] + gsrm_slf_attn_bias2_list = [] + encoder_word_pos_list.append(norm_img[1]) + gsrm_word_pos_list.append(norm_img[2]) + gsrm_slf_attn_bias1_list.append(norm_img[3]) + gsrm_slf_attn_bias2_list.append(norm_img[4]) + norm_img_batch.append(norm_img[0]) norm_img_batch = np.concatenate(norm_img_batch) norm_img_batch = norm_img_batch.copy() - starttime = time.time() - self.input_tensor.copy_from_cpu(norm_img_batch) - self.predictor.run() - outputs = [] - for output_tensor in self.output_tensors: - output = output_tensor.copy_to_cpu() - outputs.append(output) - preds = outputs[0] + + if self.rec_algorithm == "SRN": + starttime = time.time() + encoder_word_pos_list = np.concatenate(encoder_word_pos_list) + gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list) + gsrm_slf_attn_bias1_list = np.concatenate( + gsrm_slf_attn_bias1_list) + gsrm_slf_attn_bias2_list = np.concatenate( + gsrm_slf_attn_bias2_list) + + inputs = [ + norm_img_batch, + encoder_word_pos_list, + gsrm_word_pos_list, + gsrm_slf_attn_bias1_list, + gsrm_slf_attn_bias2_list, + ] + input_names = self.predictor.get_input_names() + for i in range(len(input_names)): + input_tensor = self.predictor.get_input_handle(input_names[ + i]) + input_tensor.copy_from_cpu(inputs[i]) + self.predictor.run() + outputs = [] + for output_tensor in self.output_tensors: + output = output_tensor.copy_to_cpu() + outputs.append(output) + preds = {"predict": outputs[2]} + else: + starttime = time.time() + self.input_tensor.copy_from_cpu(norm_img_batch) + self.predictor.run() + + outputs = [] + for output_tensor in self.output_tensors: + output = output_tensor.copy_to_cpu() + outputs.append(output) + preds = outputs[0] + rec_result = self.postprocess_op(preds) for rno in range(len(rec_result)): rec_res[indices[beg_img_no + rno]] = rec_result[rno] diff --git a/tools/infer_rec.py b/tools/infer_rec.py index 7e4b0811..075ec261 100755 --- a/tools/infer_rec.py +++ b/tools/infer_rec.py @@ -62,7 +62,13 @@ def main(): elif op_name in ['RecResizeImg']: op[op_name]['infer_mode'] = True elif op_name == 'KeepKeys': - op[op_name]['keep_keys'] = ['image'] + if config['Architecture']['algorithm'] == "SRN": + op[op_name]['keep_keys'] = [ + 'image', 'encoder_word_pos', 'gsrm_word_pos', + 'gsrm_slf_attn_bias1', 'gsrm_slf_attn_bias2' + ] + else: + op[op_name]['keep_keys'] = ['image'] transforms.append(op) global_config['infer_mode'] = True ops = create_operators(transforms, global_config) @@ -74,10 +80,25 @@ def main(): img = f.read() data = {'image': img} batch = transform(data, ops) + if config['Architecture']['algorithm'] == "SRN": + encoder_word_pos_list = np.expand_dims(batch[1], axis=0) + gsrm_word_pos_list = np.expand_dims(batch[2], axis=0) + gsrm_slf_attn_bias1_list = np.expand_dims(batch[3], axis=0) + gsrm_slf_attn_bias2_list = np.expand_dims(batch[4], axis=0) + + others = [ + paddle.to_tensor(encoder_word_pos_list), + paddle.to_tensor(gsrm_word_pos_list), + paddle.to_tensor(gsrm_slf_attn_bias1_list), + paddle.to_tensor(gsrm_slf_attn_bias2_list) + ] images = np.expand_dims(batch[0], axis=0) images = paddle.to_tensor(images) - preds = model(images) + if config['Architecture']['algorithm'] == "SRN": + preds = model(images, others) + else: + preds = model(images) post_result = post_process_class(preds) for rec_reuslt in post_result: logger.info('\t result: {}'.format(rec_reuslt)) diff --git a/tools/program.py b/tools/program.py index c2915426..ce52a610 100755 --- a/tools/program.py +++ b/tools/program.py @@ -179,9 +179,9 @@ def train(config, if 'start_epoch' in best_model_dict: start_epoch = best_model_dict['start_epoch'] else: - start_epoch = 1 + start_epoch = 0 - for epoch in range(start_epoch, epoch_num + 1): + for epoch in range(start_epoch, epoch_num): if epoch > 0: train_dataloader = build_dataloader(config, 'Train', device, logger) train_batch_cost = 0.0 @@ -194,7 +194,11 @@ def train(config, break lr = optimizer.get_lr() images = batch[0] - preds = model(images) + if config['Architecture']['algorithm'] == "SRN": + others = batch[-4:] + preds = model(images, others) + else: + preds = model(images) loss = loss_class(preds, batch) avg_loss = loss['loss'] avg_loss.backward() @@ -212,6 +216,7 @@ def train(config, stats['lr'] = lr train_stats.update(stats) + #cal_metric_during_train = False if cal_metric_during_train: # onlt rec and cls need batch = [item.numpy() for item in batch] post_result = post_process_class(preds, batch[1]) @@ -312,8 +317,9 @@ def eval(model, valid_dataloader, post_process_class, eval_class): if idx >= len(valid_dataloader): break images = batch[0] + others = batch[-4:] start = time.time() - preds = model(images) + preds = model(images, others) batch = [item.numpy() for item in batch] # Obtain usable results from post-processing methods From 297871d4be965b760da6ed1535fad82354cfd366 Mon Sep 17 00:00:00 2001 From: tink2123 Date: Wed, 30 Dec 2020 19:54:16 +0800 Subject: [PATCH 02/10] fix bugs --- ppocr/metrics/__init__.py | 1 - tools/program.py | 5 ++--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/ppocr/metrics/__init__.py b/ppocr/metrics/__init__.py index 41828f51..a0e7d912 100644 --- a/ppocr/metrics/__init__.py +++ b/ppocr/metrics/__init__.py @@ -26,7 +26,6 @@ def build_metric(config): from .det_metric import DetMetric from .rec_metric import RecMetric from .cls_metric import ClsMetric - from .rec_metric import RecMetric support_dict = ['DetMetric', 'RecMetric', 'ClsMetric'] diff --git a/tools/program.py b/tools/program.py index ce52a610..08bc4c81 100755 --- a/tools/program.py +++ b/tools/program.py @@ -179,9 +179,9 @@ def train(config, if 'start_epoch' in best_model_dict: start_epoch = best_model_dict['start_epoch'] else: - start_epoch = 0 + start_epoch = 1 - for epoch in range(start_epoch, epoch_num): + for epoch in range(start_epoch, epoch_num + 1): if epoch > 0: train_dataloader = build_dataloader(config, 'Train', device, logger) train_batch_cost = 0.0 @@ -216,7 +216,6 @@ def train(config, stats['lr'] = lr train_stats.update(stats) - #cal_metric_during_train = False if cal_metric_during_train: # onlt rec and cls need batch = [item.numpy() for item in batch] post_result = post_process_class(preds, batch[1]) From 841adff934bbef4967b64bbf029a9bc454578adf Mon Sep 17 00:00:00 2001 From: Karthikeyan Singaravelan Date: Thu, 31 Dec 2020 07:05:06 +0000 Subject: [PATCH 03/10] Fix syntax warning over comparison of literals using is. --- PPOCRLabel/PPOCRLabel.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/PPOCRLabel/PPOCRLabel.py b/PPOCRLabel/PPOCRLabel.py index b4c73083..4d2108e6 100644 --- a/PPOCRLabel/PPOCRLabel.py +++ b/PPOCRLabel/PPOCRLabel.py @@ -1032,7 +1032,7 @@ class MainWindow(QMainWindow, WindowMixin): for box in self.result_dic: trans_dic = {"label": box[1][0], "points": box[0], 'difficult': False} - if trans_dic["label"] is "" and mode == 'Auto': + if trans_dic["label"] == "" and mode == 'Auto': continue shapes.append(trans_dic) @@ -1791,7 +1791,7 @@ class MainWindow(QMainWindow, WindowMixin): QMessageBox.information(self, "Information", msg) return result = self.ocr.ocr(img_crop, cls=True, det=False) - if result[0][0] is not '': + if result[0][0] != '': result.insert(0, box) print('result in reRec is ', result) self.result_dic.append(result) @@ -1822,7 +1822,7 @@ class MainWindow(QMainWindow, WindowMixin): QMessageBox.information(self, "Information", msg) return result = self.ocr.ocr(img_crop, cls=True, det=False) - if result[0][0] is not '': + if result[0][0] != '': result.insert(0, box) print('result in reRec is ', result) if result[1][0] == shape.label: @@ -2008,7 +2008,7 @@ if __name__ == '__main__': resource_file = './libs/resources.py' if not os.path.exists(resource_file): output = os.system('pyrcc5 -o libs/resources.py resources.qrc') - assert output is 0, "operate the cmd have some problems ,please check whether there is a in the lib " \ + assert output == 0, "operate the cmd have some problems ,please check whether there is a in the lib " \ "directory resources.py " import libs.resources sys.exit(main()) From 93670ab5a2dc59d589f82e0c1a952e295ef3c86e Mon Sep 17 00:00:00 2001 From: tink2123 Date: Tue, 19 Jan 2021 06:48:52 +0000 Subject: [PATCH 04/10] all ready --- configs/rec/rec_r50_fpn_srn.yml | 9 +++++---- ppocr/modeling/heads/self_attention.py | 1 + ppocr/postprocess/rec_postprocess.py | 7 ++++--- tools/program.py | 7 +++++++ 4 files changed, 17 insertions(+), 7 deletions(-) diff --git a/configs/rec/rec_r50_fpn_srn.yml b/configs/rec/rec_r50_fpn_srn.yml index 78f8d551..ec7f1705 100644 --- a/configs/rec/rec_r50_fpn_srn.yml +++ b/configs/rec/rec_r50_fpn_srn.yml @@ -3,7 +3,7 @@ Global: epoch_num: 72 log_smooth_window: 20 print_batch_step: 5 - save_model_dir: ./output/rec/srn + save_model_dir: ./output/rec/srn_new save_epoch_step: 3 # evaluation is run every 5000 iterations after the 4000th iteration eval_batch_step: [0, 5000] @@ -25,8 +25,10 @@ Global: Optimizer: name: Adam + beta1: 0.9 + beta2: 0.999 + clip_norm: 10.0 lr: - name: Cosine learning_rate: 0.0001 Architecture: @@ -58,7 +60,6 @@ Train: dataset: name: LMDBDataSet data_dir: ./train_data/srn_train_data_duiqi - #label_file_list: ["./train_data/ic15_data/1.txt"] transforms: - DecodeImage: # load image img_mode: BGR @@ -77,7 +78,7 @@ Train: loader: shuffle: False batch_size_per_card: 64 - drop_last: True + drop_last: False num_workers: 4 Eval: diff --git a/ppocr/modeling/heads/self_attention.py b/ppocr/modeling/heads/self_attention.py index 6aeb8f0c..51d5198f 100644 --- a/ppocr/modeling/heads/self_attention.py +++ b/ppocr/modeling/heads/self_attention.py @@ -359,6 +359,7 @@ class PrepareDecoder(nn.Layer): self.emb0 = paddle.nn.Embedding( num_embeddings=src_vocab_size, embedding_dim=self.src_emb_dim, + padding_idx=bos_idx, weight_attr=paddle.ParamAttr( name=word_emb_param_name, initializer=nn.initializer.Normal(0., src_emb_dim**-0.5))) diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index c2303cea..867f920a 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -182,14 +182,15 @@ class SRNLabelDecode(BaseRecLabelDecode): preds_prob = np.reshape(preds_prob, [-1, 25]) - text = self.decode(preds_idx, preds_prob) + text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True) if label is None: + text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) return text - label = self.decode(label, is_remove_duplicate=False) + label = self.decode(label, is_remove_duplicate=True) return text, label - def decode(self, text_index, text_prob=None, is_remove_duplicate=True): + def decode(self, text_index, text_prob=None, is_remove_duplicate=False): """ convert text-index into text-label. """ result_list = [] ignored_tokens = self.get_ignored_tokens() diff --git a/tools/program.py b/tools/program.py index 08bc4c81..885d45f5 100755 --- a/tools/program.py +++ b/tools/program.py @@ -242,6 +242,12 @@ def train(config, # eval if global_step > start_eval_step and \ (global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0: + model_average = paddle.optimizer.ModelAverage( + 0.15, + parameters=model.parameters(), + min_average_window=10000, + max_average_window=15625) + model_average.apply() cur_metirc = eval(model, valid_dataloader, post_process_class, eval_class) cur_metirc_str = 'cur metirc, {}'.format(', '.join( @@ -277,6 +283,7 @@ def train(config, best_model_dict[main_indicator], global_step) global_step += 1 + optimizer.clear_grad() batch_start = time.time() if dist.get_rank() == 0: save_model( From ed2f0de95e58298ee733ee83976ef43079a613a0 Mon Sep 17 00:00:00 2001 From: tink2123 Date: Fri, 22 Jan 2021 03:15:56 +0000 Subject: [PATCH 05/10] mv model_average to incubate --- ppocr/losses/rec_srn_loss.py | 2 +- ppocr/postprocess/rec_postprocess.py | 4 ++-- tools/program.py | 15 +++++++++------ 3 files changed, 12 insertions(+), 9 deletions(-) diff --git a/ppocr/losses/rec_srn_loss.py b/ppocr/losses/rec_srn_loss.py index d722ee0f..7d5b65eb 100644 --- a/ppocr/losses/rec_srn_loss.py +++ b/ppocr/losses/rec_srn_loss.py @@ -42,6 +42,6 @@ class SRNLoss(nn.Layer): cost_gsrm = paddle.reshape(x=paddle.sum(cost_gsrm), shape=[1]) cost_vsfd = paddle.reshape(x=paddle.sum(cost_vsfd), shape=[1]) - sum_cost = cost_word + cost_vsfd * 2.0 + cost_gsrm * 0.15 + sum_cost = cost_word * 3.0 + cost_vsfd + cost_gsrm * 0.15 return {'loss': sum_cost, 'word_loss': cost_word, 'img_loss': cost_vsfd} diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py index 867f920a..8c972a14 100644 --- a/ppocr/postprocess/rec_postprocess.py +++ b/ppocr/postprocess/rec_postprocess.py @@ -182,12 +182,12 @@ class SRNLabelDecode(BaseRecLabelDecode): preds_prob = np.reshape(preds_prob, [-1, 25]) - text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True) + text = self.decode(preds_idx, preds_prob) if label is None: text = self.decode(preds_idx, preds_prob, is_remove_duplicate=False) return text - label = self.decode(label, is_remove_duplicate=True) + label = self.decode(label) return text, label def decode(self, text_index, text_prob=None, is_remove_duplicate=False): diff --git a/tools/program.py b/tools/program.py index 885d45f5..f329dcd5 100755 --- a/tools/program.py +++ b/tools/program.py @@ -174,6 +174,7 @@ def train(config, best_model_dict = {main_indicator: 0} best_model_dict.update(pre_best_model_dict) train_stats = TrainingStats(log_smooth_window, ['lr']) + model_average = False model.train() if 'start_epoch' in best_model_dict: @@ -197,6 +198,7 @@ def train(config, if config['Architecture']['algorithm'] == "SRN": others = batch[-4:] preds = model(images, others) + model_average = True else: preds = model(images) loss = loss_class(preds, batch) @@ -242,12 +244,13 @@ def train(config, # eval if global_step > start_eval_step and \ (global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0: - model_average = paddle.optimizer.ModelAverage( - 0.15, - parameters=model.parameters(), - min_average_window=10000, - max_average_window=15625) - model_average.apply() + if model_average: + Model_Average = paddle.incubate.optimizer.ModelAverage( + 0.15, + parameters=model.parameters(), + min_average_window=10000, + max_average_window=15625) + Model_Average.apply() cur_metirc = eval(model, valid_dataloader, post_process_class, eval_class) cur_metirc_str = 'cur metirc, {}'.format(', '.join( From 647db30f6f3bd4d8e3693d6ba83b2d0fea355076 Mon Sep 17 00:00:00 2001 From: Leif <4603009@qq.com> Date: Fri, 29 Jan 2021 14:51:40 +0800 Subject: [PATCH 06/10] Fix bugs during save recognition results --- PPOCRLabel/PPOCRLabel.py | 30 +++++++++++++++++++----------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/PPOCRLabel/PPOCRLabel.py b/PPOCRLabel/PPOCRLabel.py index 4d9c5274..92d80c8a 100644 --- a/PPOCRLabel/PPOCRLabel.py +++ b/PPOCRLabel/PPOCRLabel.py @@ -1450,7 +1450,7 @@ class MainWindow(QMainWindow, WindowMixin): item = QListWidgetItem(closeicon, filename) self.fileListWidget.addItem(item) - print('dirPath in importDirImages is', dirpath) + print('DirPath in importDirImages is', dirpath) self.iconlist.clear() self.additems5(dirpath) self.changeFileFolder = True @@ -1459,7 +1459,6 @@ class MainWindow(QMainWindow, WindowMixin): self.reRecogButton.setEnabled(True) self.actions.AutoRec.setEnabled(True) self.actions.reRec.setEnabled(True) - self.actions.saveLabel.setEnabled(True) def openPrevImg(self, _value=False): @@ -1862,6 +1861,8 @@ class MainWindow(QMainWindow, WindowMixin): for each in states: file, state = each.split('\t') self.fileStatedict[file] = 1 + self.actions.saveLabel.setEnabled(True) + self.actions.saveRec.setEnabled(True) def saveFilestate(self): @@ -1919,22 +1920,29 @@ class MainWindow(QMainWindow, WindowMixin): rec_gt_dir = os.path.dirname(self.PPlabelpath) + '/rec_gt.txt' crop_img_dir = os.path.dirname(self.PPlabelpath) + '/crop_img/' + ques_img = [] if not os.path.exists(crop_img_dir): os.mkdir(crop_img_dir) with open(rec_gt_dir, 'w', encoding='utf-8') as f: for key in self.fileStatedict: idx = self.getImglabelidx(key) - for i, label in enumerate(self.PPlabel[idx]): - if label['difficult']: continue + try: img = cv2.imread(key) - img_crop = get_rotate_crop_image(img, np.array(label['points'], np.float32)) - img_name = os.path.splitext(os.path.basename(idx))[0] + '_crop_'+str(i)+'.jpg' - cv2.imwrite(crop_img_dir+img_name, img_crop) - f.write('crop_img/'+ img_name + '\t') - f.write(label['transcription'] + '\n') - - QMessageBox.information(self, "Information", "Cropped images has been saved in "+str(crop_img_dir)) + for i, label in enumerate(self.PPlabel[idx]): + if label['difficult']: continue + img_crop = get_rotate_crop_image(img, np.array(label['points'], np.float32)) + img_name = os.path.splitext(os.path.basename(idx))[0] + '_crop_'+str(i)+'.jpg' + cv2.imwrite(crop_img_dir+img_name, img_crop) + f.write('crop_img/'+ img_name + '\t') + f.write(label['transcription'] + '\n') + except Exception as e: + ques_img.append(key) + print("Can not read image ",e) + if ques_img: + QMessageBox.information(self, "Information", "The following images can not be saved, " + "please check the image path and labels.\n" + "".join(str(i)+'\n' for i in ques_img)) + QMessageBox.information(self, "Information", "Cropped images have been saved in "+str(crop_img_dir)) def speedChoose(self): if self.labelDialogOption.isChecked(): From b3a451da2672a4ca9f0825f8c720057a91fe35f6 Mon Sep 17 00:00:00 2001 From: Leif <4603009@qq.com> Date: Fri, 29 Jan 2021 15:03:41 +0800 Subject: [PATCH 07/10] Fix a spelling mistake --- ppocr/data/lmdb_dataset.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ppocr/data/lmdb_dataset.py b/ppocr/data/lmdb_dataset.py index bd0630f6..e2d6dc93 100644 --- a/ppocr/data/lmdb_dataset.py +++ b/ppocr/data/lmdb_dataset.py @@ -20,9 +20,9 @@ import cv2 from .imaug import transform, create_operators -class LMDBDateSet(Dataset): +class LMDBDataSet(Dataset): def __init__(self, config, mode, logger, seed=None): - super(LMDBDateSet, self).__init__() + super(LMDBDataSet, self).__init__() global_config = config['Global'] dataset_config = config[mode]['dataset'] From 42fe741ff18381df2fc00b665f0b4585ab065fd7 Mon Sep 17 00:00:00 2001 From: tink2123 Date: Fri, 29 Jan 2021 15:08:58 +0800 Subject: [PATCH 08/10] add srn doc --- doc/doc_ch/algorithm_overview.md | 4 ++-- doc/doc_ch/inference.md | 21 +++++++++++++++++---- doc/doc_ch/recognition.md | 2 ++ doc/doc_en/algorithm_overview_en.md | 4 ++-- doc/doc_en/inference_en.md | 20 ++++++++++++++++++-- doc/doc_en/recognition_en.md | 1 + 6 files changed, 42 insertions(+), 10 deletions(-) diff --git a/doc/doc_ch/algorithm_overview.md b/doc/doc_ch/algorithm_overview.md index 59d1bc8c..f0765695 100755 --- a/doc/doc_ch/algorithm_overview.md +++ b/doc/doc_ch/algorithm_overview.md @@ -41,7 +41,7 @@ PaddleOCR基于动态图开源的文本识别算法列表: - [x] Rosetta([paper](https://arxiv.org/abs/1910.05085))[10] - [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11] - [ ] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12] coming soon -- [ ] SRN([paper](https://arxiv.org/abs/2003.12294))[5] coming soon +- [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5] 参考[DTRB][3](https://arxiv.org/abs/1904.01906)文字识别训练和评估流程,使用MJSynth和SynthText两个文字识别数据集训练,在IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE数据集上进行评估,算法效果如下: @@ -53,5 +53,5 @@ PaddleOCR基于动态图开源的文本识别算法列表: |CRNN|MobileNetV3|79.97%|rec_mv3_none_bilstm_ctc|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar)| |StarNet|Resnet34_vd|84.44%|rec_r34_vd_tps_bilstm_ctc|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_ctc_v2.0_train.tar)| |StarNet|MobileNetV3|81.42%|rec_mv3_tps_bilstm_ctc|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_ctc_v2.0_train.tar)| - +|SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar) | PaddleOCR文本识别算法的训练和使用请参考文档教程中[模型训练/评估中的文本识别部分](./recognition.md)。 diff --git a/doc/doc_ch/inference.md b/doc/doc_ch/inference.md index c4601e15..0daddd9b 100755 --- a/doc/doc_ch/inference.md +++ b/doc/doc_ch/inference.md @@ -22,8 +22,9 @@ inference 模型(`paddle.jit.save`保存的模型) - [三、文本识别模型推理](#文本识别模型推理) - [1. 超轻量中文识别模型推理](#超轻量中文识别模型推理) - [2. 基于CTC损失的识别模型推理](#基于CTC损失的识别模型推理) - - [3. 自定义文本识别字典的推理](#自定义文本识别字典的推理) - - [4. 多语言模型的推理](#多语言模型的推理) + - [3. 基于SRN损失的识别模型推理](#基于SRN损失的识别模型推理) + - [4. 自定义文本识别字典的推理](#自定义文本识别字典的推理) + - [5. 多语言模型的推理](#多语言模型的推理) - [四、方向分类模型推理](#方向识别模型推理) - [1. 方向分类模型推理](#方向分类模型推理) @@ -295,8 +296,20 @@ Predicts of ./doc/imgs_words_en/word_336.png:('super', 0.9999073) self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz" dict_character = list(self.character_str) ``` + +### 3. 基于SRN损失的识别模型推理 +基于SRN损失的识别模型,需要额外设置识别算法参数 --rec_algorithm="SRN"。 +同时需要保证预测shape与训练时一致,如: --rec_image_shape="1, 64, 256" -### 3. 自定义文本识别字典的推理 +``` +python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" \ + --rec_model_dir="./inference/srn/" \ + --rec_image_shape="1, 64, 256" \ + --rec_char_type="en" \ + --rec_algorithm="SRN" +``` + +### 4. 自定义文本识别字典的推理 如果训练时修改了文本的字典,在使用inference模型预测时,需要通过`--rec_char_dict_path`指定使用的字典路径,并且设置 `rec_char_type=ch` ``` @@ -304,7 +317,7 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png ``` -### 4. 多语言模型的推理 +### 5. 多语言模型的推理 如果您需要预测的是其他语言模型,在使用inference模型预测时,需要通过`--rec_char_dict_path`指定使用的字典路径, 同时为了得到正确的可视化结果, 需要通过 `--vis_font_path` 指定可视化的字体路径,`doc/fonts/` 路径下有默认提供的小语种字体,例如韩文识别: diff --git a/doc/doc_ch/recognition.md b/doc/doc_ch/recognition.md index c5f459bd..bc877ab7 100644 --- a/doc/doc_ch/recognition.md +++ b/doc/doc_ch/recognition.md @@ -36,6 +36,7 @@ ln -sf /train_data/dataset * 数据下载 若您本地没有数据集,可以在官网下载 [icdar2015](http://rrc.cvc.uab.es/?ch=4&com=downloads) 数据,用于快速验证。也可以参考[DTRB](https://github.com/clovaai/deep-text-recognition-benchmark#download-lmdb-dataset-for-traininig-and-evaluation-from-here),下载 benchmark 所需的lmdb格式数据集。 +如果希望复现SRN的论文指标,需要下载离线[增广数据](https://pan.baidu.com/s/1-HSZ-ZVdqBF2HaBZ5pRAKA),提取码: y3ry。增广数据是由MJSynth和SynthText做旋转和扰动得到的。数据下载完成后请解压到 {your_path}/PaddleOCR/train_data/data_lmdb_release/training/ 路径下。 * 使用自己数据集 @@ -200,6 +201,7 @@ PaddleOCR支持训练和评估交替进行, 可以在 `configs/rec/rec_icdar15_t | rec_mv3_none_none_ctc.yml | Rosetta | Mobilenet_v3 large 0.5 | None | None | ctc | | rec_r34_vd_none_bilstm_ctc.yml | CRNN | Resnet34_vd | None | BiLSTM | ctc | | rec_r34_vd_none_none_ctc.yml | Rosetta | Resnet34_vd | None | None | ctc | +| rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn | 训练中文数据,推荐使用[rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml),如您希望尝试其他算法在中文数据集上的效果,请参考下列说明修改配置文件: diff --git a/doc/doc_en/algorithm_overview_en.md b/doc/doc_en/algorithm_overview_en.md index 68bfd529..5016223f 100755 --- a/doc/doc_en/algorithm_overview_en.md +++ b/doc/doc_en/algorithm_overview_en.md @@ -43,7 +43,7 @@ PaddleOCR open-source text recognition algorithms list: - [x] Rosetta([paper](https://arxiv.org/abs/1910.05085))[10] - [x] STAR-Net([paper](http://www.bmva.org/bmvc/2016/papers/paper043/index.html))[11] - [ ] RARE([paper](https://arxiv.org/abs/1603.03915v1))[12] coming soon -- [ ] SRN([paper](https://arxiv.org/abs/2003.12294))[5] coming soon +- [x] SRN([paper](https://arxiv.org/abs/2003.12294))[5] Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation result of these above text recognition (using MJSynth and SynthText for training, evaluate on IIIT, SVT, IC03, IC13, IC15, SVTP, CUTE) is as follow: @@ -55,5 +55,5 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r |CRNN|MobileNetV3|79.97%|rec_mv3_none_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_none_bilstm_ctc_v2.0_train.tar)| |StarNet|Resnet34_vd|84.44%|rec_r34_vd_tps_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_ctc_v2.0_train.tar)| |StarNet|MobileNetV3|81.42%|rec_mv3_tps_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_ctc_v2.0_train.tar)| - +|SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar)| Please refer to the document for training guide and use of PaddleOCR text recognition algorithms [Text recognition model training/evaluation/prediction](./recognition_en.md) diff --git a/doc/doc_en/inference_en.md b/doc/doc_en/inference_en.md index ccbb7184..c8ce1424 100755 --- a/doc/doc_en/inference_en.md +++ b/doc/doc_en/inference_en.md @@ -25,6 +25,7 @@ Next, we first introduce how to convert a trained model into an inference model, - [TEXT RECOGNITION MODEL INFERENCE](#RECOGNITION_MODEL_INFERENCE) - [1. LIGHTWEIGHT CHINESE MODEL](#LIGHTWEIGHT_RECOGNITION) - [2. CTC-BASED TEXT RECOGNITION MODEL INFERENCE](#CTC-BASED_RECOGNITION) + - [3. SRN-BASED TEXT RECOGNITION MODEL INFERENCE](#SRN-BASED_RECOGNITION) - [3. TEXT RECOGNITION MODEL INFERENCE USING CUSTOM CHARACTERS DICTIONARY](#USING_CUSTOM_CHARACTERS) - [4. MULTILINGUAL MODEL INFERENCE](MULTILINGUAL_MODEL_INFERENCE) @@ -304,8 +305,23 @@ self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz" dict_character = list(self.character_str) ``` + +### 3. SRN-BASED TEXT RECOGNITION MODEL INFERENCE + +The recognition model based on SRN requires additional setting of the recognition algorithm parameter +--rec_algorithm="SRN". At the same time, it is necessary to ensure that the predicted shape is consistent +with the training, such as: --rec_image_shape="1, 64, 256" + +``` +python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png" \ + --rec_model_dir="./inference/srn/" \ + --rec_image_shape="1, 64, 256" \ + --rec_char_type="en" \ + --rec_algorithm="SRN" +``` + -### 3. TEXT RECOGNITION MODEL INFERENCE USING CUSTOM CHARACTERS DICTIONARY +### 4. TEXT RECOGNITION MODEL INFERENCE USING CUSTOM CHARACTERS DICTIONARY If the text dictionary is modified during training, when using the inference model to predict, you need to specify the dictionary path used by `--rec_char_dict_path`, and set `rec_char_type=ch` ``` @@ -313,7 +329,7 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png ``` -### 4. MULTILINGAUL MODEL INFERENCE +### 5. MULTILINGAUL MODEL INFERENCE If you need to predict other language models, when using inference model prediction, you need to specify the dictionary path used by `--rec_char_dict_path`. At the same time, in order to get the correct visualization results, You need to specify the visual font path through `--vis_font_path`. There are small language fonts provided by default under the `doc/fonts` path, such as Korean recognition: diff --git a/doc/doc_en/recognition_en.md b/doc/doc_en/recognition_en.md index 22f89cde..f29703d1 100644 --- a/doc/doc_en/recognition_en.md +++ b/doc/doc_en/recognition_en.md @@ -195,6 +195,7 @@ If the evaluation set is large, the test will be time-consuming. It is recommend | rec_mv3_none_none_ctc.yml | Rosetta | Mobilenet_v3 large 0.5 | None | None | ctc | | rec_r34_vd_none_bilstm_ctc.yml | CRNN | Resnet34_vd | None | BiLSTM | ctc | | rec_r34_vd_none_none_ctc.yml | Rosetta | Resnet34_vd | None | None | ctc | +| rec_r50fpn_vd_none_srn.yml | SRN | Resnet50_fpn_vd | None | rnn | srn | For training Chinese data, it is recommended to use [rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml). If you want to try the result of other algorithms on the Chinese data set, please refer to the following instructions to modify the configuration file: From 6781d55df4a705b1d0d7201e5fc6b484d4912a9b Mon Sep 17 00:00:00 2001 From: tink2123 Date: Fri, 29 Jan 2021 15:23:11 +0800 Subject: [PATCH 09/10] format doc --- doc/doc_ch/algorithm_overview.md | 1 + doc/doc_en/algorithm_overview_en.md | 1 + 2 files changed, 2 insertions(+) diff --git a/doc/doc_ch/algorithm_overview.md b/doc/doc_ch/algorithm_overview.md index f0765695..abbc5da4 100755 --- a/doc/doc_ch/algorithm_overview.md +++ b/doc/doc_ch/algorithm_overview.md @@ -54,4 +54,5 @@ PaddleOCR基于动态图开源的文本识别算法列表: |StarNet|Resnet34_vd|84.44%|rec_r34_vd_tps_bilstm_ctc|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_ctc_v2.0_train.tar)| |StarNet|MobileNetV3|81.42%|rec_mv3_tps_bilstm_ctc|[下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_ctc_v2.0_train.tar)| |SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn | [下载链接](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar) | + PaddleOCR文本识别算法的训练和使用请参考文档教程中[模型训练/评估中的文本识别部分](./recognition.md)。 diff --git a/doc/doc_en/algorithm_overview_en.md b/doc/doc_en/algorithm_overview_en.md index 5016223f..7d7896e7 100755 --- a/doc/doc_en/algorithm_overview_en.md +++ b/doc/doc_en/algorithm_overview_en.md @@ -56,4 +56,5 @@ Refer to [DTRB](https://arxiv.org/abs/1904.01906), the training and evaluation r |StarNet|Resnet34_vd|84.44%|rec_r34_vd_tps_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r34_vd_tps_bilstm_ctc_v2.0_train.tar)| |StarNet|MobileNetV3|81.42%|rec_mv3_tps_bilstm_ctc|[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_mv3_tps_bilstm_ctc_v2.0_train.tar)| |SRN|Resnet50_vd_fpn| 88.52% | rec_r50fpn_vd_none_srn |[Download link](https://paddleocr.bj.bcebos.com/dygraph_v2.0/en/rec_r50_vd_srn_train.tar)| + Please refer to the document for training guide and use of PaddleOCR text recognition algorithms [Text recognition model training/evaluation/prediction](./recognition_en.md) From 2a0c3d4dac67cfd49e432303443bb9a50e75071f Mon Sep 17 00:00:00 2001 From: xiaoting <31891223+tink2123@users.noreply.github.com> Date: Sun, 31 Jan 2021 22:37:30 +0800 Subject: [PATCH 10/10] fix eval mode without srn (#1889) * fix base model * fix start time --- tools/program.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tools/program.py b/tools/program.py index 694d6415..f3ba4945 100755 --- a/tools/program.py +++ b/tools/program.py @@ -326,9 +326,12 @@ def eval(model, valid_dataloader, post_process_class, eval_class): if idx >= len(valid_dataloader): break images = batch[0] - others = batch[-4:] start = time.time() - preds = model(images, others) + if "SRN" in str(model.head): + others = batch[-4:] + preds = model(images, others) + else: + preds = model(images) batch = [item.numpy() for item in batch] # Obtain usable results from post-processing methods