diff --git a/PPOCRLabel/PPOCRLabel.py b/PPOCRLabel/PPOCRLabel.py
index 92d80c8a..8babac65 100644
--- a/PPOCRLabel/PPOCRLabel.py
+++ b/PPOCRLabel/PPOCRLabel.py
@@ -1031,7 +1031,7 @@ class MainWindow(QMainWindow, WindowMixin):
for box in self.result_dic:
trans_dic = {"label": box[1][0], "points": box[0], 'difficult': False}
- if trans_dic["label"] is "" and mode == 'Auto':
+ if trans_dic["label"] == "" and mode == 'Auto':
continue
shapes.append(trans_dic)
@@ -1763,7 +1763,7 @@ class MainWindow(QMainWindow, WindowMixin):
QMessageBox.information(self, "Information", msg)
return
result = self.ocr.ocr(img_crop, cls=True, det=False)
- if result[0][0] is not '':
+ if result[0][0] != '':
result.insert(0, box)
print('result in reRec is ', result)
self.result_dic.append(result)
@@ -1794,7 +1794,7 @@ class MainWindow(QMainWindow, WindowMixin):
QMessageBox.information(self, "Information", msg)
return
result = self.ocr.ocr(img_crop, cls=True, det=False)
- if result[0][0] is not '':
+ if result[0][0] != '':
result.insert(0, box)
print('result in reRec is ', result)
if result[1][0] == shape.label:
@@ -1999,7 +1999,7 @@ if __name__ == '__main__':
resource_file = './libs/resources.py'
if not os.path.exists(resource_file):
output = os.system('pyrcc5 -o libs/resources.py resources.qrc')
-    assert output is 0, "operate the cmd have some problems ,please check whether there is a in the lib " \
-        "directory resources.py "
+    assert output == 0, "pyrcc5 failed; please check whether resources.py " \
+        "was generated in the libs directory"
import libs.resources
sys.exit(main())
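The `is` → `==` changes above matter because `is` tests object identity while `==` tests value equality; comparing against a literal with `is` only works by accident of CPython's interning and raises a `SyntaxWarning` on Python 3.8+. A minimal sketch of the difference:

```python
a = 1000
b = int("1000")  # same value, built at runtime -> a distinct object
print(a == b)    # True: the values are equal
print(a is b)    # typically False in CPython: two different int objects
```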
diff --git a/README.md b/README.md
index 5b6e4bd0..67d65e98 100644
--- a/README.md
+++ b/README.md
@@ -5,10 +5,11 @@ PaddleOCR aims to create multilingual, awesome, leading, and practical OCR tools
## Notice
PaddleOCR supports both dynamic graph and static graph programming paradigms
-- Dynamic graph: dygraph branch (default), **supported by paddle 2.0rc1+ ([installation](./doc/doc_en/installation_en.md))**
+- Dynamic graph: dygraph branch (default), **supported by paddle 2.0.0 ([installation](./doc/doc_en/installation_en.md))**
- Static graph: develop branch
**Recent updates**
+- 2021.1.21 update more than 25 multilingual recognition models ([models list](./doc/doc_en/models_list_en.md)), including English, Chinese, German, French, Japanese, Spanish, Portuguese, Russian, Arabic, and more. Models for more languages will continue to be updated ([Develop Plan](https://github.com/PaddlePaddle/PaddleOCR/issues/1048)).
- 2020.12.15 Update the data synthesis tool, i.e., [Style-Text](./StyleText/README.md), which makes it easy to synthesize large numbers of images similar to the target scene image.
- 2020.11.25 Update a new data annotation tool, i.e., [PPOCRLabel](./PPOCRLabel/README.md), which is helpful to improve the labeling efficiency. Moreover, the labeling results can be used in training of the PP-OCR system directly.
- 2020.9.22 Update the PP-OCR technical article, https://arxiv.org/abs/2009.09941
diff --git a/README_ch.md b/README_ch.md
index 2de6fdf5..d627ec45 100755
--- a/README_ch.md
+++ b/README_ch.md
@@ -4,11 +4,13 @@
PaddleOCR aims to create a rich, leading and practical OCR toolkit that helps users train better models and apply them in practice.
## Notice
PaddleOCR supports both dynamic graph and static graph programming paradigms
-- Dynamic graph version: dygraph branch (default), requires upgrading paddle to 2.0rc1+ ([quick installation](./doc/doc_ch/installation.md))
+- Dynamic graph version: dygraph branch (default), requires upgrading paddle to 2.0.0 ([quick installation](./doc/doc_ch/installation.md))
- Static graph version: develop branch
**Recent updates**
-- 2021.1.18 [FAQ](./doc/doc_ch/FAQ.md) added 5 frequently asked questions, for a total of 152; updated every Monday, stay tuned.
+- 2021.1.26,28,29 The PaddleOCR R&D team presents a three-day live course with in-depth technical analysis, at 19:30 on the evenings of January 26, 28 and 29, [live stream](https://live.bilibili.com/21689802)
+- 2021.1.25 [FAQ](./doc/doc_ch/FAQ.md) added 5 frequently asked questions, for a total of 157; updated every Monday, stay tuned.
+- 2021.1.21 Updated the multilingual recognition models, now covering more than 27 languages ([multilingual model download](./doc/doc_ch/models_list.md)), including Simplified Chinese, Traditional Chinese, English, French, German, Korean, Japanese, Italian, Spanish, Portuguese, Russian, Arabic and more; for follow-up plans see the [multilingual development plan](https://github.com/PaddlePaddle/PaddleOCR/issues/1048)
- 2020.12.15 Updated the data synthesis tool [Style-Text](./StyleText/README_ch.md), which can batch-synthesize large numbers of images similar to the target scene; verified in multiple scenarios with clearly improved results.
- 2020.11.25 Updated the semi-automatic annotation tool [PPOCRLabel](./PPOCRLabel/README_ch.md), which helps developers complete annotation tasks efficiently; the output format connects seamlessly with PP-OCR training.
- 2020.9.22 Updated the PP-OCR technical article, https://arxiv.org/abs/2009.09941
diff --git a/StyleText/README.md b/StyleText/README.md
index df4fbf3c..65a72ac8 100644
--- a/StyleText/README.md
+++ b/StyleText/README.md
@@ -72,7 +72,7 @@ fusion_generator:
python3 tools/synth_image.py -c configs/config.yml --style_image examples/style_images/2.jpg --text_corpus PaddleOCR --language en
```
-* Note 1: The language options is correspond to the corpus. Currently, the tool only supports English, Simplified Chinese and Korean.
+* Note 1: The language option corresponds to the corpus. Currently, the tool only supports English (en), Simplified Chinese (ch) and Korean (ko).
* Note 2: Synth-Text is mainly used to generate images for OCR recognition models.
So the height of style images should be around 32 pixels. Images of other sizes may behave poorly.
* Note 3: You can modify `use_gpu` in `configs/config.yml` to determine whether to use GPU for prediction.
@@ -120,7 +120,7 @@ In actual application scenarios, it is often necessary to synthesize pictures in
* `with_label`: Whether `label_file` is a list of label files.
* `CorpusGenerator`:
  * `method`: Method of CorpusGenerator; supports `FileCorpus` and `EnNumCorpus`. If `EnNumCorpus` is used, no other configuration is needed; otherwise you need to set `corpus_file` and `language`.
- * `language`:Language of the corpus.
+  * `language`: Language of the corpus. Currently, the tool only supports English (en), Simplified Chinese (ch) and Korean (ko).
  * `corpus_file`: Filepath of the corpus. The corpus file should be a text file that will be split by line endings ('\n'); the corpus generator samples one line each time.
diff --git a/StyleText/README_ch.md b/StyleText/README_ch.md
index fd259ca0..ccd1efaf 100644
--- a/StyleText/README_ch.md
+++ b/StyleText/README_ch.md
@@ -63,10 +63,10 @@ fusion_generator:
```bash
python3 tools/synth_image.py -c configs/config.yml --style_image examples/style_images/2.jpg --text_corpus PaddleOCR --language en
```
-* Note 1: The language option corresponds to the corpus; currently the tool only supports English, Simplified Chinese and Korean.
+* Note 1: The language option corresponds to the corpus; currently English (en), Simplified Chinese (ch) and Korean (ko) are supported.
* Note 2: The data generated by Style-Text is mainly used in OCR recognition scenarios. Given the design of the current PaddleOCR recognition models, we mainly support style images with a height of around 32 pixels.
If the input image size differs too much, the result may be poor.
-* Note 3: You can decide whether to use GPU for prediction by modifying the `use_gpu` (true or false) parameter in the configuration file.
+* Note 3: You can decide whether to use GPU for prediction by modifying the `use_gpu` (true or false) parameter in the configuration file `configs/config.yml`.
For example, given the following image and the corpus "PaddleOCR" as input:
@@ -105,7 +105,7 @@ python3 tools/synth_image.py -c configs/config.yml --style_image examples/style_
* `with_label`: whether `label_file` is a label file list.
* `CorpusGenerator`:
  * `method`: the corpus generation method; `FileCorpus` and `EnNumCorpus` are currently available. If `EnNumCorpus` is used, no other configuration is needed; otherwise `corpus_file` and `language` must be set;
-  * `language`: the language of the corpus;
+  * `language`: the language of the corpus; currently English (en), Simplified Chinese (ch) and Korean (ko) are supported;
* `corpus_file`: path of the corpus file. The corpus file should be a text file; the generator first splits it by line and then randomly samples one line each time.
Example of the corpus file format:
diff --git a/configs/rec/multi_language/rec_en_number_lite_train.yml b/configs/rec/multi_language/rec_en_number_lite_train.yml
index cee05121..13eda848 100644
--- a/configs/rec/multi_language/rec_en_number_lite_train.yml
+++ b/configs/rec/multi_language/rec_en_number_lite_train.yml
@@ -16,7 +16,7 @@ Global:
infer_img:
# for data or label process
character_dict_path: ppocr/utils/dict/en_dict.txt
- character_type: ch
+ character_type: EN
max_text_length: 25
infer_mode: False
use_space_char: False
diff --git a/configs/rec/rec_mv3_none_bilstm_ctc.yml b/configs/rec/rec_mv3_none_bilstm_ctc.yml
index 38f1e869..00c1db88 100644
--- a/configs/rec/rec_mv3_none_bilstm_ctc.yml
+++ b/configs/rec/rec_mv3_none_bilstm_ctc.yml
@@ -1,5 +1,5 @@
Global:
- use_gpu: true
+ use_gpu: True
epoch_num: 72
log_smooth_window: 20
print_batch_step: 10
@@ -59,7 +59,7 @@ Metric:
Train:
dataset:
- name: LMDBDateSet
+ name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/
transforms:
- DecodeImage: # load image
@@ -78,7 +78,7 @@ Train:
Eval:
dataset:
- name: LMDBDateSet
+ name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/
transforms:
- DecodeImage: # load image
diff --git a/configs/rec/rec_mv3_none_none_ctc.yml b/configs/rec/rec_mv3_none_none_ctc.yml
index 33079ad4..6711b1d2 100644
--- a/configs/rec/rec_mv3_none_none_ctc.yml
+++ b/configs/rec/rec_mv3_none_none_ctc.yml
@@ -58,7 +58,7 @@ Metric:
Train:
dataset:
- name: LMDBDateSet
+ name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/
transforms:
- DecodeImage: # load image
@@ -77,7 +77,7 @@ Train:
Eval:
dataset:
- name: LMDBDateSet
+ name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/
transforms:
- DecodeImage: # load image
diff --git a/configs/rec/rec_mv3_tps_bilstm_ctc.yml b/configs/rec/rec_mv3_tps_bilstm_ctc.yml
index 08f68939..1b9fb0a0 100644
--- a/configs/rec/rec_mv3_tps_bilstm_ctc.yml
+++ b/configs/rec/rec_mv3_tps_bilstm_ctc.yml
@@ -63,7 +63,7 @@ Metric:
Train:
dataset:
- name: LMDBDateSet
+ name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/
transforms:
- DecodeImage: # load image
@@ -82,7 +82,7 @@ Train:
Eval:
dataset:
- name: LMDBDateSet
+ name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/
transforms:
- DecodeImage: # load image
diff --git a/configs/rec/rec_r34_vd_none_bilstm_ctc.yml b/configs/rec/rec_r34_vd_none_bilstm_ctc.yml
index 4ad2ff89..e4d301a6 100644
--- a/configs/rec/rec_r34_vd_none_bilstm_ctc.yml
+++ b/configs/rec/rec_r34_vd_none_bilstm_ctc.yml
@@ -58,7 +58,7 @@ Metric:
Train:
dataset:
- name: LMDBDateSet
+ name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/
transforms:
- DecodeImage: # load image
@@ -77,7 +77,7 @@ Train:
Eval:
dataset:
- name: LMDBDateSet
+ name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/
transforms:
- DecodeImage: # load image
diff --git a/configs/rec/rec_r34_vd_none_none_ctc.yml b/configs/rec/rec_r34_vd_none_none_ctc.yml
index 9c1eeb30..4a17a004 100644
--- a/configs/rec/rec_r34_vd_none_none_ctc.yml
+++ b/configs/rec/rec_r34_vd_none_none_ctc.yml
@@ -56,7 +56,7 @@ Metric:
Train:
dataset:
- name: LMDBDateSet
+ name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/
transforms:
- DecodeImage: # load image
@@ -75,7 +75,7 @@ Train:
Eval:
dataset:
- name: LMDBDateSet
+ name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/
transforms:
- DecodeImage: # load image
diff --git a/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml b/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml
index aeded492..62edf843 100644
--- a/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml
+++ b/configs/rec/rec_r34_vd_tps_bilstm_ctc.yml
@@ -62,7 +62,7 @@ Metric:
Train:
dataset:
- name: LMDBDateSet
+ name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/training/
transforms:
- DecodeImage: # load image
@@ -81,7 +81,7 @@ Train:
Eval:
dataset:
- name: LMDBDateSet
+ name: LMDBDataSet
data_dir: ./train_data/data_lmdb_release/validation/
transforms:
- DecodeImage: # load image
diff --git a/configs/rec/rec_r50_fpn_srn.yml b/configs/rec/rec_r50_fpn_srn.yml
new file mode 100644
index 00000000..ec7f1705
--- /dev/null
+++ b/configs/rec/rec_r50_fpn_srn.yml
@@ -0,0 +1,107 @@
+Global:
+ use_gpu: True
+ epoch_num: 72
+ log_smooth_window: 20
+ print_batch_step: 5
+ save_model_dir: ./output/rec/srn_new
+ save_epoch_step: 3
+  # evaluation is run every 5000 iterations (starting from iteration 0)
+  eval_batch_step: [0, 5000]
+  # if pretrained_model is saved in static mode, load_static_weights must be set to True
+ cal_metric_during_train: True
+ pretrained_model:
+ checkpoints:
+ save_inference_dir:
+ use_visualdl: False
+ infer_img: doc/imgs_words/ch/word_1.jpg
+ # for data or label process
+ character_dict_path:
+ character_type: en
+ max_text_length: 25
+ num_heads: 8
+ infer_mode: False
+ use_space_char: False
+
+
+Optimizer:
+ name: Adam
+ beta1: 0.9
+ beta2: 0.999
+ clip_norm: 10.0
+ lr:
+ learning_rate: 0.0001
+
+Architecture:
+ model_type: rec
+ algorithm: SRN
+ in_channels: 1
+ Transform:
+ Backbone:
+ name: ResNetFPN
+ Head:
+ name: SRNHead
+ max_text_length: 25
+ num_heads: 8
+ num_encoder_TUs: 2
+ num_decoder_TUs: 4
+ hidden_dims: 512
+
+Loss:
+ name: SRNLoss
+
+PostProcess:
+ name: SRNLabelDecode
+
+Metric:
+ name: RecMetric
+ main_indicator: acc
+
+Train:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ./train_data/srn_train_data_duiqi
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - SRNLabelEncode: # Class handling label
+ - SRNRecResizeImg:
+ image_shape: [1, 64, 256]
+ - KeepKeys:
+ keep_keys: ['image',
+ 'label',
+ 'length',
+ 'encoder_word_pos',
+ 'gsrm_word_pos',
+ 'gsrm_slf_attn_bias1',
+ 'gsrm_slf_attn_bias2'] # dataloader will return list in this order
+ loader:
+ shuffle: False
+ batch_size_per_card: 64
+ drop_last: False
+ num_workers: 4
+
+Eval:
+ dataset:
+ name: LMDBDataSet
+ data_dir: ./train_data/data_lmdb_release/evaluation
+ transforms:
+ - DecodeImage: # load image
+ img_mode: BGR
+ channel_first: False
+ - SRNLabelEncode: # Class handling label
+ - SRNRecResizeImg:
+ image_shape: [1, 64, 256]
+ - KeepKeys:
+ keep_keys: ['image',
+ 'label',
+ 'length',
+ 'encoder_word_pos',
+ 'gsrm_word_pos',
+ 'gsrm_slf_attn_bias1',
+ 'gsrm_slf_attn_bias2']
+ loader:
+ shuffle: False
+ drop_last: False
+ batch_size_per_card: 32
+ num_workers: 4
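For context, a hedged sketch of how a PaddleOCR-style YAML config such as the new `rec_r50_fpn_srn.yml` can be inspected and overridden. This is illustrative only (it assumes PyYAML is installed) and is not part of the patch:

```python
import yaml

# Load the SRN config added above and apply a -o style override.
with open("configs/rec/rec_r50_fpn_srn.yml") as f:
    config = yaml.safe_load(f)

config["Global"]["use_gpu"] = False            # e.g. -o Global.use_gpu=False
print(config["Architecture"]["algorithm"])     # SRN
print(config["Train"]["loader"]["batch_size_per_card"])  # 64
```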
diff --git a/deploy/slim/quantization/README.md b/deploy/slim/quantization/README.md
index ccd4d06b..4ac3f7c3 100644
--- a/deploy/slim/quantization/README.md
+++ b/deploy/slim/quantization/README.md
@@ -42,7 +42,7 @@ python deploy/slim/quantization/quant.py -c configs/det/det_mv3_db.yml -o Global
# For example, download the provided trained model
wget https://paddleocr.bj.bcebos.com/dygraph_v2.0/ch/ch_ppocr_mobile_v2.0_det_train.tar
tar -xf ch_ppocr_mobile_v2.0_det_train.tar
-python deploy/slim/quantization/quant.py -c configs/det/det_mv3_db.yml -o Global.pretrain_weights=./ch_ppocr_mobile_v2.0_det_train/best_accuracy Global.save_model_dir=./output/quant_model
+python deploy/slim/quantization/quant.py -c configs/det/det_mv3_db.yml -o Global.pretrain_weights=./ch_ppocr_mobile_v2.0_det_train/best_accuracy Global.save_inference_dir=./output/quant_inference_model
```
To train a quantized recognition model, simply modify the configuration file and the loaded model parameters.
diff --git a/deploy/slim/quantization/README_en.md b/deploy/slim/quantization/README_en.md
index 7da0b3e7..36407a2b 100644
--- a/deploy/slim/quantization/README_en.md
+++ b/deploy/slim/quantization/README_en.md
@@ -58,7 +58,7 @@ python deploy/slim/quantization/quant.py -c configs/det/det_mv3_db.yml -o Global
After getting the quantized and fine-tuned model, we can export it as an inference model for predictive deployment:
```bash
-python deploy/slim/quantization/export_model.py -c configs/det/det_mv3_db.yml -o Global.checkpoints=output/quant_model/best_accuracy Global.save_model_dir=./output/quant_inference_model
+python deploy/slim/quantization/export_model.py -c configs/det/det_mv3_db.yml -o Global.checkpoints=output/quant_model/best_accuracy Global.save_inference_dir=./output/quant_inference_model
```
### 5. Deploy
diff --git a/doc/doc_ch/FAQ.md b/doc/doc_ch/FAQ.md
index 37b9834d..11a9f35d 100755
--- a/doc/doc_ch/FAQ.md
+++ b/doc/doc_ch/FAQ.md
@@ -9,42 +9,43 @@
## PaddleOCR FAQ (continuously updated)
-* [Recent updates (2021.1.18)](#近期更新)
+* [Recent updates (2021.1.25)](#近期更新)
* [[Selected] Top 10 OCR questions](#OCR精选10个问题)
* [[Theory] 32 general OCR questions](#OCR通用问题)
  * [Basics: 7 questions](#基础知识)
  * [Datasets: 7 questions](#数据集2)
  * [Model training/tuning: 18 questions](#模型训练调优2)
-* [[Practice] 110 PaddleOCR practical questions](#PaddleOCR实战问题)
-  * [Usage consulting: 36 questions](#使用咨询)
+* [[Practice] 115 PaddleOCR practical questions](#PaddleOCR实战问题)
+  * [Usage consulting: 38 questions](#使用咨询)
  * [Datasets: 17 questions](#数据集3)
  * [Model training/tuning: 28 questions](#模型训练调优3)
-  * [Prediction/deployment: 29 questions](#预测部署3)
+  * [Prediction/deployment: 32 questions](#预测部署3)
-## Recent updates (2021.1.18)
+## Recent updates (2021.1.25)
+#### Q3.1.37: Do the minor-language models come only as recognition models, without detection models?
-#### Q2.3.18: In the PP-OCR system, why doesn't the text detection backbone use the SE module?
+**A**: The detection model for minor languages (including pure English/digit text) is shared with the Chinese detection model; multilingual data was added when training the Chinese detection model. https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/doc/doc_en/models_list_en.md#1-text-detection-model.
-**A**: The SE module is an important part of MobileNetV3: it estimates the importance of each feature channel of the feature map, assigns a weight to each feature, and improves the network's expressive power. However, for text detection the input resolution is large, typically 640*640, so estimating per-channel importance with the SE module is difficult and brings limited gains while being fairly time-consuming; therefore the text detection backbone in the PP-OCR system does not use it. Experiments also show that removing the SE module shrinks the ultra-lightweight model by 40% with almost no impact on detection quality. For details see the PP-OCR technical article, https://arxiv.org/abs/2009.09941.
+#### Q3.1.38: module 'paddle.distributed' has no attribute 'get_rank'.
-#### Q3.3.27: Which data augmentation methods does PaddleOCR support when training text recognition models?
+**A**: This is a Paddle version problem; please install Paddle 2.0: pip install paddlepaddle==2.0.0.
-**A**: Supported augmentations for text recognition include random slight cropping, image balancing, white-noise addition, color drift, image inversion and Text Image Augmentation (TIA). See the warp function in the [code](../../ppocr/data/imaug/rec_img_aug.py).
+#### Q3.4.30: Does PaddleOCR support deployment on the Huawei Kunpeng 920 CPU?
-#### Q3.3.28: In the dygraph branch, how should data augmentation be configured when training a text recognition model?
+**A**: Paddle's inference library currently supports the Huawei Kunpeng 920 CPU, but OCR has not been tested on these chips yet; you can debug it yourself and report any problems back to us.
-**A**: Refer to the [configuration file](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml) and add a RecAug field under Train['dataset']['transforms'] to enable augmentation. You can set aug_prob, the probability with which each augmentation is applied; the default is 0.4. Because TIA augmentation is special, it is disabled by default and can be enabled via use_tia. For details see [ISSUE 1744](https://github.com/PaddlePaddle/PaddleOCR/issues/1744).
+#### Q3.4.31: Problems occur when deploying on-device with Paddle-Lite, although the environment is fine.
-#### Q3.4.28: In the PP-OCR system, do text detection results carry confidence scores?
+**A**: If you compiled the inference library yourself, you must also compile the nb file yourself with the same Paddle-Lite version; you cannot use a downloaded nb file directly, because the versions differ.
-**A**: Text detection results do have confidence scores, but since they are not used during inference they are not returned in the final output. If you need them, add the scores at line 155 of the [DB post-processing code](../../ppocr/postprocess/db_postprocess.py); you can then retrieve them at line 197 of the [detection prediction code](../../tools/infer/predict_det.py).
+#### Q3.4.32: Do PaddleOCR models support conversion to ONNX?
-#### Q3.4.29: Where is the code that builds the feature pyramid for DB text detection?
-
-**A**: The feature-pyramid construction lives here: [code location](../../ppocr/modeling/necks/db_fpn.py). The ppocr/modeling folder holds the network-building code: architectures contains the overall detection/recognition pipelines; backbones the backbone networks; necks the FPN-like neck functions; heads the head functions that produce detection/recognition outputs; transforms the TPS-like preprocessing modules. See the [code organization](./tree.md) for more.
+**A**: We now support converting each model suite via Paddle2ONNX; PaddleOCR based on PaddlePaddle 2.0 (the dygraph branch) already supports export to ONNX. Follow Paddle2ONNX for the latest progress:
+Paddle2ONNX project: https://github.com/PaddlePaddle/Paddle2ONNX
+[List of models](https://github.com/PaddlePaddle/Paddle2ONNX/blob/develop/docs/zh/model_zoo.md#%E5%9B%BE%E5%83%8Focr) that Paddle2ONNX can convert
## [Selected] Top 10 OCR questions
@@ -396,13 +397,13 @@
**A**: The dynamic graph version is under intensive development and will be released on December 16, 2020; stay tuned.
#### Q3.1.22:ModuleNotFoundError: No module named 'paddle.nn',
-**A**: paddle.nn is a feature specific to Paddle 2.0; please install Paddle 2.0.0rc1 or later, using
+**A**: paddle.nn is a feature specific to Paddle 2.0; please install Paddle 2.0.0 or later, using
```
-python3 -m pip install paddlepaddle-gpu==2.0.0rc1 -i https://mirror.baidu.com/pypi/simple
+python3 -m pip install paddlepaddle-gpu==2.0.0 -i https://mirror.baidu.com/pypi/simple
```
#### Q3.1.23: ImportError: /usr/lib/x86_64_linux-gnu/libstdc++.so.6:version `CXXABI_1.3.11` not found (required by /usr/lib/python3.6/site-package/paddle/fluid/core+avx.so)
-**A**: This problem is caused by an insufficient glibc version. Paddle 2.0rc1 has higher requirements for the gcc and glibc versions; gcc 8.2 and glibc 2.12 or above are recommended.
+**A**: This problem is caused by an insufficient glibc version. Paddle 2.0.0 has higher requirements for the gcc and glibc versions; gcc 8.2 and glibc 2.12 or above are recommended.
If your environment does not meet this requirement, or you are using the docker image
`hub.baidubce.com/paddlepaddle/paddle:latest-gpu-cuda9.0-cudnn7-dev`, the above error may occur when installing the Paddle 2.0rc version; for 2.0, the new docker image `paddlepaddle/paddle:latest-dev-cuda10.1-cudnn7-gcc82` is recommended.
@@ -414,7 +415,7 @@ python3 -m pip install paddlepaddle-gpu==2.0.0rc1 -i https://mirror.baidu.com/py
- develop: a branch developed on the Paddle static graph; paddle 1.8 or 2.0 is recommended. It has complete model training, prediction, inference deployment and quantization/pruning features, and is ahead of the release/1.1 branch.
- release/1.1: the first stable release of PaddleOCR, developed on the static graph, with complete training, prediction, inference deployment and quantization/pruning features.
-- dygraph: a branch developed on the Paddle dynamic graph, still under development; it will become the main development branch in the future and requires Paddle 2.0rc1.
+- dygraph: a branch developed on the Paddle dynamic graph, still under development; it will become the main development branch in the future and requires Paddle 2.0.0.
- release/2.0-rc1-0: the second stable release of PaddleOCR, developed on the dynamic graph and paddle 2.0; dynamic-graph code is easier to debug. It currently supports model training and prediction, but not yet mobile deployment.
If you have already worked with PaddleOCR and want to deploy it in various environments, we currently recommend the static-graph branches, develop or release/1.1. If you are a beginner who wants to quickly train and debug the PaddleOCR algorithms, try the dygraph branch.
@@ -431,7 +432,7 @@ python3 -m pip install paddlepaddle-gpu==2.0.0rc1 -i https://mirror.baidu.com/py
#### Q3.1.27: How do I visualize acc/loss curves, the model's network structure, etc.?
-**A**: The configuration file provides a `use_visualdl` parameter; set it to True. For more usage see the [VisualDL guide](https://www.paddlepaddle.org.cn/documentation/docs/zh/2.0-rc1/guides/03_VisualDL/visualdl.html).
+**A**: The configuration file provides a `use_visualdl` parameter; set it to True. For more usage see the [VisualDL guide](https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/guides/03_VisualDL/visualdl.html).
#### Q3.1.28: When using the StyleText data synthesis tool, the error `ModuleNotFoundError: No module named 'utils.config'` is reported. Why?
@@ -450,7 +451,7 @@ https://github.com/PaddlePaddle/PaddleOCR/blob/de3e2e7cd3b8b65ee02d7a41e570fa5b5
#### Q3.1.31: How do I print the network structure and the parameter info of each layer?
-**A**: You can use `paddle.summary`; see: https://www.paddlepaddle.org.cn/documentation/docs/zh/2.0-rc1/api/paddle/hapi/model_summary/summary_cn.html#summary.
+**A**: You can use `paddle.summary`; see: https://www.paddlepaddle.org.cn/documentation/docs/zh/develop/api/paddle/hapi/model_summary/summary_cn.html.
#### Q3.1.32: Can the resolution in the StyleText configuration file be modified?
@@ -474,9 +475,18 @@ StyleText is mainly used to extract style information, such as fonts and backgrounds, from style_image
For example, when recognizing ID-card photos, you can first match keywords such as "name" and "gender", use their coordinates to infer the positions of the other fields, and then match those against the recognition results.
#### Q3.1.36 How can ancient texts on bamboo slips be recognized?
+
**A**: If the characters are all ordinary Chinese characters, annotating enough data and fine-tuning the model is sufficient. If the data volume is insufficient, you can try the StyleText tool.
But if the characters are special ancient scripts, oracle-bone script, pictographs and so on, you first need to build a dictionary of the ancient characters and then train on it.
+#### Q3.1.37: Do the minor-language models come only as recognition models, without detection models?
+
+**A**: The detection model for minor languages (including pure English/digit text) is shared with the Chinese detection model; multilingual data was added when training the Chinese detection model. https://github.com/PaddlePaddle/PaddleOCR/blob/dygraph/doc/doc_en/models_list_en.md#1-text-detection-model.
+
+#### Q3.1.38: module 'paddle.distributed' has no attribute 'get_rank'.
+
+**A**: This is a Paddle version problem; please install Paddle 2.0: pip install paddlepaddle==2.0.0.
+
### Datasets
@@ -854,3 +864,17 @@ img = cv.imdecode(img_array, -1)
#### Q3.4.29: Where is the code that builds the feature pyramid for DB text detection?
**A**: The feature-pyramid construction lives here: [code location](../../ppocr/modeling/necks/db_fpn.py). The ppocr/modeling folder holds the network-building code: architectures contains the overall detection/recognition pipelines; backbones the backbone networks; necks the FPN-like neck functions; heads the head functions that produce detection/recognition outputs; transforms the TPS-like preprocessing modules. See the [code organization](./tree.md) for more.
+
+#### Q3.4.30: Does PaddleOCR support deployment on the Huawei Kunpeng 920 CPU?
+
+**A**: Paddle's inference library currently supports the Huawei Kunpeng 920 CPU, but OCR has not been tested on these chips yet; you can debug it yourself and report any problems back to us.
+
+#### Q3.4.31: Problems occur when deploying on-device with Paddle-Lite, although the environment is fine.
+
+**A**: If you compiled the inference library yourself, you must also compile the nb file yourself with the same Paddle-Lite version; you cannot use a downloaded nb file directly, because the versions differ.
+
+#### Q3.4.32: Do PaddleOCR models support conversion to ONNX?
+
+**A**: We now support converting each model suite via Paddle2ONNX; PaddleOCR based on PaddlePaddle 2.0 (the dygraph branch) already supports export to ONNX. Follow Paddle2ONNX for the latest progress:
+Paddle2ONNX project: https://github.com/PaddlePaddle/Paddle2ONNX
+[List of models](https://github.com/PaddlePaddle/Paddle2ONNX/blob/develop/docs/zh/model_zoo.md#%E5%9B%BE%E5%83%8Focr) that Paddle2ONNX can convert
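As a quick sanity check for Q3.1.38 above (illustrative only, not part of the patch): `paddle.distributed.get_rank` is available from Paddle 2.0.0 onward, so verifying the installed version usually settles it:

```python
import paddle
import paddle.distributed as dist

print(paddle.__version__)  # expect '2.0.0' or later
print(dist.get_rank())     # 0 in a single-process run
```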
diff --git a/doc/doc_ch/angle_class.md b/doc/doc_ch/angle_class.md
index 4d7ff0d7..6e68134a 100644
--- a/doc/doc_ch/angle_class.md
+++ b/doc/doc_ch/angle_class.md
@@ -63,7 +63,7 @@ PaddleOCR provides training, evaluation and prediction scripts.
*If you installed the CPU version, please change the `use_gpu` field in the configuration file to false*
```
-# GPU training: single-card and multi-card training are supported; specify the card numbers via '--gpus'. If your paddle version is lower than 2.0rc1, use the '--select_gpus' parameter to choose GPUs
+# GPU training: single-card and multi-card training are supported; specify the card numbers via '--gpus'.
# Start training. The command below is already written into the train.sh file; just modify the configuration file path in it
python3 -m paddle.distributed.launch --gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/cls/cls_mv3.yml
```
diff --git a/doc/doc_ch/detection.md b/doc/doc_ch/detection.md
index 8f0f6979..a8dee65a 100644
--- a/doc/doc_ch/detection.md
+++ b/doc/doc_ch/detection.md
@@ -76,7 +76,7 @@ tar -xf ./pretrain_models/MobileNetV3_large_x0_5_pretrained.tar ./pretrain_model
# Single-machine single-card training of the mv3_db model
python3 tools/train.py -c configs/det/det_mv3_db.yml \
-o Global.pretrain_weights=./pretrain_models/MobileNetV3_large_x0_5_pretrained/
-# Single-machine multi-card training: set the GPU IDs to use via the --gpus parameter; if your paddle version is lower than 2.0rc1, use the '--select_gpus' parameter to choose GPUs
+# Single-machine multi-card training: set the GPU IDs to use via the --gpus parameter
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/det/det_mv3_db.yml \
-o Global.pretrain_weights=./pretrain_models/MobileNetV3_large_x0_5_pretrained/
```
diff --git a/doc/doc_ch/inference.md b/doc/doc_ch/inference.md
index ab548703..c4601e15 100755
--- a/doc/doc_ch/inference.md
+++ b/doc/doc_ch/inference.md
@@ -306,10 +306,10 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png
### 4. Multilingual model inference
If you need to predict with other language models, you must specify the dictionary path via `--rec_char_dict_path` when predicting with the inference model; to get correct visualization results,
-you also need to specify the visualization font path via `--vis_font_path`; fonts for minor languages are provided by default under the `doc/` path, e.g. for Korean recognition:
+you also need to specify the visualization font path via `--vis_font_path`; fonts for minor languages are provided by default under the `doc/fonts/` path, e.g. for Korean recognition:
```
-python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/korean/1.jpg" --rec_model_dir="./your inference model" --rec_char_type="korean" --rec_char_dict_path="ppocr/utils/dict/korean_dict.txt" --vis_font_path="doc/korean.ttf"
+python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/korean/1.jpg" --rec_model_dir="./your inference model" --rec_char_type="korean" --rec_char_dict_path="ppocr/utils/dict/korean_dict.txt" --vis_font_path="doc/fonts/korean.ttf"
```
![](../imgs_words/korean/1.jpg)
diff --git a/doc/doc_ch/installation.md b/doc/doc_ch/installation.md
index 36565cd4..fce151eb 100644
--- a/doc/doc_ch/installation.md
+++ b/doc/doc_ch/installation.md
@@ -2,7 +2,7 @@
PaddleOCR has been tested to run on glibc 2.23; you can also test other glibc versions or install glibc 2.23
PaddleOCR working environment
-- PaddlePaddle 1.8+; PaddlePaddle 2.0rc1 recommended
+- PaddlePaddle 2.0.0
- python3.7
- glibc 2.23
- cuDNN 7.6+ (GPU)
@@ -35,11 +35,11 @@ sudo docker container exec -it ppocr /bin/bash
pip3 install --upgrade pip
If CUDA9 or CUDA10 is installed on your machine, run the following command to install
-python3 -m pip install paddlepaddle-gpu==2.0.0rc1 -i https://mirror.baidu.com/pypi/simple
+python3 -m pip install paddlepaddle-gpu==2.0.0 -i https://mirror.baidu.com/pypi/simple
If your machine is CPU-only, run the following command to install
-python3 -m pip install paddlepaddle==2.0.0rc1 -i https://mirror.baidu.com/pypi/simple
+python3 -m pip install paddlepaddle==2.0.0 -i https://mirror.baidu.com/pypi/simple
For more version requirements, please follow the instructions in the [installation document](https://www.paddlepaddle.org.cn/install/quick).
```
diff --git a/doc/doc_ch/recognition.md b/doc/doc_ch/recognition.md
index b473f3ac..c5f459bd 100644
--- a/doc/doc_ch/recognition.md
+++ b/doc/doc_ch/recognition.md
@@ -195,8 +195,6 @@ PaddleOCR supports alternating training and evaluation; you can, in `configs/rec/rec_icdar15_t
| :--------: | :-------: | :-------: | :-------: | :-----: | :-----: |
| [rec_chinese_lite_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_lite_train_v2.0.yml) | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc |
| [rec_chinese_common_train_v2.0.yml](../../configs/rec/ch_ppocr_v2.0/rec_chinese_common_train_v2.0.yml) | CRNN | ResNet34_vd | None | BiLSTM | ctc |
-| rec_chinese_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc |
-| rec_chinese_common_train.yml | CRNN | ResNet34_vd | None | BiLSTM | ctc |
| rec_icdar15_train.yml | CRNN | Mobilenet_v3 large 0.5 | None | BiLSTM | ctc |
| rec_mv3_none_bilstm_ctc.yml | CRNN | Mobilenet_v3 large 0.5 | None | BiLSTM | ctc |
| rec_mv3_none_none_ctc.yml | Rosetta | Mobilenet_v3 large 0.5 | None | None | ctc |
@@ -272,16 +270,109 @@ Eval:
- Minor languages
-PaddleOCR also provides multilingual configuration files under the `configs/rec/multi_languages` path; the multilingual algorithms currently supported by PaddleOCR are:
+PaddleOCR now supports recognition for 26 languages (besides Chinese); a multi-language configuration file template is provided under the `configs/rec/multi_languages` path: [rec_multi_language_lite_train.yml](../../configs/rec/multi_language/rec_multi_language_lite_train.yml).
-| Configuration file | Algorithm name | backbone | trans | seq | pred | language |
-| :--------: | :-------: | :-------: | :-------: | :-----: | :-----: | :-----: |
-| rec_en_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | English |
-| rec_french_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | French |
-| rec_ger_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | German |
-| rec_japan_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Japanese |
-| rec_korean_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Korean |
+You have two ways to create the required configuration file:
+1. Generate it automatically by script
+
+[generate_multi_language_configs.py](../../configs/rec/multi_language/generate_multi_language_configs.py) can help you generate configuration files for multi-language models
+
+- Take Italian as an example: if your data is prepared in the following format:
+    ```
+    |-train_data
+        |- it_train.txt # training set labels
+        |- it_val.txt # validation set labels
+        |- data
+            |- word_001.jpg
+            |- word_002.jpg
+            |- word_003.jpg
+            | ...
+    ```
+
+  You can generate a configuration file with the default parameters:
+
+  ```bash
+  # This command needs to be run in the specified directory
+  cd PaddleOCR/configs/rec/multi_language/
+  # Set the language of the configuration file to generate via the -l or --language parameter; this command writes the default parameters into the configuration file
+  python3 generate_multi_language_configs.py -l it
+  ```
+
+- If your data is placed elsewhere, or you want to use your own dictionary, you can generate the configuration file by specifying the relevant parameters:
+
+  ```bash
+  # The -l or --language field is required
+  # --train sets the training set, --val the validation set, --data_dir the dataset directory, --dict the dictionary path, -o the corresponding default parameters
+  cd PaddleOCR/configs/rec/multi_language/
+  python3 generate_multi_language_configs.py -l it \  # language
+  --train {path/of/train_label.txt} \ # path of the training label file
+  --val {path/of/val_label.txt} \     # path of the validation label file
+  --data_dir {train_data/path} \      # root directory of the training data
+  --dict {path/of/dict} \             # path of the dictionary file
+  -o Global.use_gpu=False             # whether to use gpu
+  ...
+
+  ```
+
+2. Manually modify the configuration file
+
+   You can also manually modify the following fields in the template:
+
+   ```
+   Global:
+     use_gpu: True
+     epoch_num: 500
+     ...
+     character_type: it  # language to recognize
+     character_dict_path: {path/of/dict}  # path of the dictionary file
+
+   Train:
+     dataset:
+       name: SimpleDataSet
+       data_dir: train_data/  # root directory of the data
+       label_file_list: ["./train_data/train_list.txt"]  # training set label path
+       ...
+
+   Eval:
+     dataset:
+       name: SimpleDataSet
+       data_dir: train_data/  # root directory of the data
+       label_file_list: ["./train_data/val_list.txt"]  # validation set label path
+       ...
+
+   ```
+
+The multilingual algorithms currently supported by PaddleOCR are:
+
+| Configuration file | Algorithm name | backbone | trans | seq | pred | language | character_type |
+| :--------: | :-------: | :-------: | :-------: | :-----: | :-----: | :-----: | :-----: |
+| rec_chinese_cht_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Chinese traditional | chinese_cht|
+| rec_en_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | English (case sensitive) | EN |
+| rec_french_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | French | french |
+| rec_ger_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | German | german |
+| rec_japan_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Japanese | japan |
+| rec_korean_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Korean | korean |
+| rec_it_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Italian | it |
+| rec_xi_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Spanish | xi |
+| rec_pu_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Portuguese | pu |
+| rec_ru_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Russian | ru |
+| rec_ar_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Arabic | ar |
+| rec_hi_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Hindi | hi |
+| rec_ug_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Uyghur | ug |
+| rec_fa_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Persian | fa |
+| rec_ur_ite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Urdu | ur |
+| rec_rs_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Serbian (latin) | rs |
+| rec_oc_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Occitan | oc |
+| rec_mr_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Marathi | mr |
+| rec_ne_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Nepali | ne |
+| rec_rsc_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Serbian (cyrillic) | rsc |
+| rec_bg_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Bulgarian | bg |
+| rec_uk_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Ukrainian | uk |
+| rec_be_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Belarusian | be |
+| rec_te_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Telugu | te |
+| rec_ka_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Kannada | ka |
+| rec_ta_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Tamil | ta |
The multilingual models are trained in the same way as the Chinese model; the training set consists of 1 million synthetic images. A small number of fonts can be downloaded from [Baidu Netdisk](https://pan.baidu.com/s/1bS_u207Rm7YbY33wOECKDA), extraction code: frgi.
diff --git a/doc/doc_en/angle_class_en.md b/doc/doc_en/angle_class_en.md
index 8d932870..d1cc712f 100644
--- a/doc/doc_en/angle_class_en.md
+++ b/doc/doc_en/angle_class_en.md
@@ -66,7 +66,7 @@ Start training:
```
# Set PYTHONPATH path
export PYTHONPATH=$PYTHONPATH:.
-# GPU training Support single card and multi-card training, specify the card number through --gpus. If your paddle version is less than 2.0rc1, please use '--selected_gpus'
+# GPU training Support single card and multi-card training, specify the card number through --gpus.
# Start training, the following command has been written into the train.sh file, just modify the configuration file path in the file
python3 -m paddle.distributed.launch --gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/cls/cls_mv3.yml
```
diff --git a/doc/doc_en/detection_en.md b/doc/doc_en/detection_en.md
index 5c4a63e2..3ee9092c 100644
--- a/doc/doc_en/detection_en.md
+++ b/doc/doc_en/detection_en.md
@@ -76,7 +76,7 @@ You can also use `-o` to change the training parameters without modifying the ym
python3 tools/train.py -c configs/det/det_mv3_db.yml -o Optimizer.base_lr=0.0001
# multi-GPU training
-# Set the GPU ID used by the '--gpus' parameter; If your paddle version is less than 2.0rc1, please use '--selected_gpus'
+# Set the GPU ID used by the '--gpus' parameter.
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/det/det_mv3_db.yml -o Optimizer.base_lr=0.0001
diff --git a/doc/doc_en/inference_en.md b/doc/doc_en/inference_en.md
index 98e3ef63..ccbb7184 100755
--- a/doc/doc_en/inference_en.md
+++ b/doc/doc_en/inference_en.md
@@ -315,10 +315,10 @@ python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words_en/word_336.png
### 4. MULTILINGUAL MODEL INFERENCE
If you need to predict with other language models, you need to specify the dictionary path via `--rec_char_dict_path` when using an inference model. At the same time, in order to get correct visualization results,
-You need to specify the visual font path through `--vis_font_path`. There are small language fonts provided by default under the `doc/` path, such as Korean recognition:
+you need to specify the visualization font path through `--vis_font_path`. Fonts for minor languages are provided by default under the `doc/fonts` path, for example for Korean recognition:
```
-python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/korean/1.jpg" --rec_model_dir="./your inference model" --rec_char_type="korean" --rec_char_dict_path="ppocr/utils/dict/korean_dict.txt" --vis_font_path="doc/korean.ttf"
+python3 tools/infer/predict_rec.py --image_dir="./doc/imgs_words/korean/1.jpg" --rec_model_dir="./your inference model" --rec_char_type="korean" --rec_char_dict_path="ppocr/utils/dict/korean_dict.txt" --vis_font_path="doc/fonts/korean.ttf"
```
![](../imgs_words/korean/1.jpg)
diff --git a/doc/doc_en/installation_en.md b/doc/doc_en/installation_en.md
index 7f1f0e83..35c1881d 100644
--- a/doc/doc_en/installation_en.md
+++ b/doc/doc_en/installation_en.md
@@ -3,7 +3,7 @@
After testing, paddleocr can run on glibc 2.23. You can also test other glibc versions or install glibc 2.23 for the best compatibility.
PaddleOCR working environment:
-- PaddlePaddle 1.8+, Recommend PaddlePaddle 2.0rc1
+- PaddlePaddle 2.0.0
- python3.7
- glibc 2.23
@@ -38,10 +38,10 @@ sudo docker container exec -it ppocr /bin/bash
pip3 install --upgrade pip
# If you have cuda9 or cuda10 installed on your machine, please run the following command to install
-python3 -m pip install paddlepaddle-gpu==2.0rc1 -i https://mirror.baidu.com/pypi/simple
+python3 -m pip install paddlepaddle-gpu==2.0.0 -i https://mirror.baidu.com/pypi/simple
# If you only have cpu on your machine, please run the following command to install
-python3 -m pip install paddlepaddle==2.0rc1 -i https://mirror.baidu.com/pypi/simple
+python3 -m pip install paddlepaddle==2.0.0 -i https://mirror.baidu.com/pypi/simple
```
For more software version requirements, please refer to the instructions in the [Installation Document](https://www.paddlepaddle.org.cn/install/quick).
diff --git a/doc/doc_en/models_list_en.md b/doc/doc_en/models_list_en.md
index 3eb0cd23..33033f83 100644
--- a/doc/doc_en/models_list_en.md
+++ b/doc/doc_en/models_list_en.md
@@ -93,7 +93,7 @@ python3 generate_multi_language_configs.py -l it \
|model name|description|config|model size|download|
| --- | --- | --- | --- | --- |
| french_mobile_v2.0_rec |Lightweight model for French recognition|[rec_french_lite_train.yml](../../configs/rec/multi_language/rec_french_lite_train.yml)|2.65M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/french_mobile_v2.0_rec_train.tar) |
-| german_mobile_v2.0_rec |Lightweight model for French recognition|[rec_german_lite_train.yml](../../configs/rec/multi_language/rec_german_lite_train.yml)|2.65M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_train.tar) |
+| german_mobile_v2.0_rec |Lightweight model for German recognition|[rec_german_lite_train.yml](../../configs/rec/multi_language/rec_german_lite_train.yml)|2.65M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/german_mobile_v2.0_rec_train.tar) |
| korean_mobile_v2.0_rec |Lightweight model for Korean recognition|[rec_korean_lite_train.yml](../../configs/rec/multi_language/rec_korean_lite_train.yml)|3.9M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/korean_mobile_v2.0_rec_train.tar) |
| japan_mobile_v2.0_rec |Lightweight model for Japanese recognition|[rec_japan_lite_train.yml](../../configs/rec/multi_language/rec_japan_lite_train.yml)|4.23M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/japan_mobile_v2.0_rec_train.tar) |
| it_mobile_v2.0_rec |Lightweight model for Italian recognition|rec_it_lite_train.yml|2.53M|[inference model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/it_mobile_v2.0_rec_infer.tar) / [trained model](https://paddleocr.bj.bcebos.com/dygraph_v2.0/multilingual/it_mobile_v2.0_rec_train.tar) |
diff --git a/doc/doc_en/recognition_en.md b/doc/doc_en/recognition_en.md
index 7723d20b..22f89cde 100644
--- a/doc/doc_en/recognition_en.md
+++ b/doc/doc_en/recognition_en.md
@@ -266,15 +266,116 @@ Eval:
- Multi-language
-PaddleOCR also provides multi-language. The configuration file in `configs/rec/multi_languages` provides multi-language configuration files. Currently, the multi-language algorithms supported by PaddleOCR are:
+PaddleOCR currently supports recognition for 26 languages besides Chinese. A multi-language configuration file template is
+provided under the path `configs/rec/multi_languages`: [rec_multi_language_lite_train.yml](../../configs/rec/multi_language/rec_multi_language_lite_train.yml).
+
+There are two ways to create the required configuration file:
+
+1. Automatically generated by script
+
+[generate_multi_language_configs.py](../../configs/rec/multi_language/generate_multi_language_configs.py) can help you generate configuration files for multi-language models
+
+- Take Italian as an example: if your data is prepared in the following format:
+ ```
+ |-train_data
+ |- it_train.txt # train_set label
+ |- it_val.txt # val_set label
+ |- data
+ |- word_001.jpg
+ |- word_002.jpg
+ |- word_003.jpg
+ | ...
+ ```
+
+ You can use the default parameters to generate a configuration file:
+
+ ```bash
+ # The code needs to be run in the specified directory
+ cd PaddleOCR/configs/rec/multi_language/
+ # Set the configuration file of the language to be generated through the -l or --language parameter.
+ # This command will write the default parameters into the configuration file
+ python3 generate_multi_language_configs.py -l it
+ ```
+
+- If your data is placed in another location, or you want to use your own dictionary, you can generate the configuration file by specifying the relevant parameters:
+
+ ```bash
+ # -l or --language field is required
+ # --train to modify the training set
+ # --val to modify the validation set
+ # --data_dir to modify the data set directory
+ # --dict to modify the dict path
+ # -o to modify the corresponding default parameters
+ cd PaddleOCR/configs/rec/multi_language/
+ python3 generate_multi_language_configs.py -l it \ # language
+ --train {path/of/train_label.txt} \ # path of train_label
+ --val {path/of/val_label.txt} \ # path of val_label
+ --data_dir {train_data/path} \ # root directory of training data
+ --dict {path/of/dict} \ # path of dict
+ -o Global.use_gpu=False # whether to use gpu
+ ...
+
+ ```
+
+2. Manually modify the configuration file
+
+ You can also manually modify the following fields in the template:
+
+ ```
+ Global:
+ use_gpu: True
+ epoch_num: 500
+ ...
+ character_type: it # language
+ character_dict_path: {path/of/dict} # path of dict
+
+ Train:
+ dataset:
+ name: SimpleDataSet
+ data_dir: train_data/ # root directory of training data
+ label_file_list: ["./train_data/train_list.txt"] # train label path
+ ...
+
+ Eval:
+ dataset:
+ name: SimpleDataSet
+ data_dir: train_data/ # root directory of val data
+ label_file_list: ["./train_data/val_list.txt"] # val label path
+ ...
+
+ ```
+
+Currently, the multi-language algorithms supported by PaddleOCR are:
+
+| Configuration file | Algorithm name | backbone | trans | seq | pred | language | character_type |
+| :--------: | :-------: | :-------: | :-------: | :-----: | :-----: | :-----: | :-----: |
+| rec_chinese_cht_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Chinese traditional | chinese_cht|
+| rec_en_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | English (case sensitive) | EN |
+| rec_french_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | French | french |
+| rec_ger_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | German | german |
+| rec_japan_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Japanese | japan |
+| rec_korean_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Korean | korean |
+| rec_it_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Italian | it |
+| rec_xi_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Spanish | xi |
+| rec_pu_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Portuguese | pu |
+| rec_ru_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Russian | ru |
+| rec_ar_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Arabic | ar |
+| rec_hi_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Hindi | hi |
+| rec_ug_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Uyghur | ug |
+| rec_fa_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Persian(Farsi) | fa |
+| rec_ur_ite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Urdu | ur |
+| rec_rs_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Serbian(latin) | rs |
+| rec_oc_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Occitan | oc |
+| rec_mr_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Marathi | mr |
+| rec_ne_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Nepali | ne |
+| rec_rsc_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Serbian(cyrillic) | rsc |
+| rec_bg_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Bulgarian | bg |
+| rec_uk_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Ukrainian | uk |
+| rec_be_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Belarusian | be |
+| rec_te_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Telugu | te |
+| rec_ka_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Kannada | ka |
+| rec_ta_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Tamil | ta |
-| Configuration file | Algorithm name | backbone | trans | seq | pred | language |
-| :--------: | :-------: | :-------: | :-------: | :-----: | :-----: | :-----: |
-| rec_en_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | English |
-| rec_french_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | French |
-| rec_ger_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | German |
-| rec_japan_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Japanese |
-| rec_korean_lite_train.yml | CRNN | Mobilenet_v3 small 0.5 | None | BiLSTM | ctc | Korean |
The multi-language models are trained in the same way as the Chinese model. The training set consists of 1 million synthetic images. A small number of fonts and some test data can be downloaded from [Baidu Netdisk](https://pan.baidu.com/s/1bS_u207Rm7YbY33wOECKDA), extraction code: frgi.
diff --git a/doc/fonts/arabic.ttf b/doc/fonts/arabic.ttf
new file mode 100644
index 00000000..064b6041
Binary files /dev/null and b/doc/fonts/arabic.ttf differ
diff --git a/doc/fonts/chinese_cht.ttf b/doc/fonts/chinese_cht.ttf
new file mode 100644
index 00000000..3416754f
Binary files /dev/null and b/doc/fonts/chinese_cht.ttf differ
diff --git a/doc/fonts/cyrillic.ttf b/doc/fonts/cyrillic.ttf
new file mode 100644
index 00000000..be4bf660
Binary files /dev/null and b/doc/fonts/cyrillic.ttf differ
diff --git a/doc/french.ttf b/doc/fonts/french.ttf
similarity index 100%
rename from doc/french.ttf
rename to doc/fonts/french.ttf
diff --git a/doc/german.ttf b/doc/fonts/german.ttf
similarity index 100%
rename from doc/german.ttf
rename to doc/fonts/german.ttf
diff --git a/doc/fonts/hindi.ttf b/doc/fonts/hindi.ttf
new file mode 100644
index 00000000..8b0c36f5
Binary files /dev/null and b/doc/fonts/hindi.ttf differ
diff --git a/doc/japan.ttc b/doc/fonts/japan.ttc
similarity index 100%
rename from doc/japan.ttc
rename to doc/fonts/japan.ttc
diff --git a/doc/fonts/kannada.ttf b/doc/fonts/kannada.ttf
new file mode 100644
index 00000000..43b60d42
Binary files /dev/null and b/doc/fonts/kannada.ttf differ
diff --git a/doc/korean.ttf b/doc/fonts/korean.ttf
similarity index 100%
rename from doc/korean.ttf
rename to doc/fonts/korean.ttf
diff --git a/doc/fonts/latin.ttf b/doc/fonts/latin.ttf
new file mode 100644
index 00000000..e392413a
Binary files /dev/null and b/doc/fonts/latin.ttf differ
diff --git a/doc/fonts/marathi.ttf b/doc/fonts/marathi.ttf
new file mode 100644
index 00000000..a796d3ed
Binary files /dev/null and b/doc/fonts/marathi.ttf differ
diff --git a/doc/fonts/nepali.ttf b/doc/fonts/nepali.ttf
new file mode 100644
index 00000000..8b0c36f5
Binary files /dev/null and b/doc/fonts/nepali.ttf differ
diff --git a/doc/fonts/persian.ttf b/doc/fonts/persian.ttf
new file mode 100644
index 00000000..bdb1c8d7
Binary files /dev/null and b/doc/fonts/persian.ttf differ
diff --git a/doc/simfang.ttf b/doc/fonts/simfang.ttf
similarity index 100%
rename from doc/simfang.ttf
rename to doc/fonts/simfang.ttf
diff --git a/doc/fonts/spanish.ttf b/doc/fonts/spanish.ttf
new file mode 100644
index 00000000..532353d2
Binary files /dev/null and b/doc/fonts/spanish.ttf differ
diff --git a/doc/fonts/tamil.ttf b/doc/fonts/tamil.ttf
new file mode 100644
index 00000000..2e9998e8
Binary files /dev/null and b/doc/fonts/tamil.ttf differ
diff --git a/doc/fonts/telugu.ttf b/doc/fonts/telugu.ttf
new file mode 100644
index 00000000..12c91e41
Binary files /dev/null and b/doc/fonts/telugu.ttf differ
diff --git a/doc/fonts/urdu.ttf b/doc/fonts/urdu.ttf
new file mode 100644
index 00000000..625feee2
Binary files /dev/null and b/doc/fonts/urdu.ttf differ
diff --git a/doc/fonts/uyghur.ttf b/doc/fonts/uyghur.ttf
new file mode 100644
index 00000000..625feee2
Binary files /dev/null and b/doc/fonts/uyghur.ttf differ
diff --git a/ppocr/data/__init__.py b/ppocr/data/__init__.py
index 4e1ff0ae..7cb50d7a 100644
--- a/ppocr/data/__init__.py
+++ b/ppocr/data/__init__.py
@@ -33,7 +33,7 @@ import paddle.distributed as dist
from ppocr.data.imaug import transform, create_operators
from ppocr.data.simple_dataset import SimpleDataSet
-from ppocr.data.lmdb_dataset import LMDBDateSet
+from ppocr.data.lmdb_dataset import LMDBDataSet
__all__ = ['build_dataloader', 'transform', 'create_operators']
@@ -51,20 +51,21 @@ signal.signal(signal.SIGINT, term_mp)
signal.signal(signal.SIGTERM, term_mp)
-def build_dataloader(config, mode, device, logger):
+def build_dataloader(config, mode, device, logger, seed=None):
config = copy.deepcopy(config)
- support_dict = ['SimpleDataSet', 'LMDBDateSet']
+ support_dict = ['SimpleDataSet', 'LMDBDataSet']
module_name = config[mode]['dataset']['name']
assert module_name in support_dict, Exception(
'DataSet only support {}'.format(support_dict))
assert mode in ['Train', 'Eval', 'Test'
], "Mode should be Train, Eval or Test."
- dataset = eval(module_name)(config, mode, logger)
+ dataset = eval(module_name)(config, mode, logger, seed)
loader_config = config[mode]['loader']
batch_size = loader_config['batch_size_per_card']
drop_last = loader_config['drop_last']
+ shuffle = loader_config['shuffle']
num_workers = loader_config['num_workers']
if 'use_shared_memory' in loader_config.keys():
use_shared_memory = loader_config['use_shared_memory']
@@ -75,14 +76,14 @@ def build_dataloader(config, mode, device, logger):
batch_sampler = DistributedBatchSampler(
dataset=dataset,
batch_size=batch_size,
- shuffle=False,
+ shuffle=shuffle,
drop_last=drop_last)
else:
#Distribute data to single card
batch_sampler = BatchSampler(
dataset=dataset,
batch_size=batch_size,
- shuffle=False,
+ shuffle=shuffle,
drop_last=drop_last)
data_loader = DataLoader(
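A minimal sketch of what the amended `build_dataloader` now wires together: `shuffle` comes from the config instead of being hard-coded `False`, and the dataset constructor accepts a `seed`. The toy dataset below is an assumption standing in for `LMDBDataSet`/`SimpleDataSet`:

```python
from paddle.io import Dataset, BatchSampler, DataLoader

class ToyDataSet(Dataset):
    """Stand-in dataset; note the new `seed` argument in __init__."""
    def __init__(self, config, mode, logger=None, seed=None):
        super(ToyDataSet, self).__init__()
        self.samples = list(range(10))
    def __getitem__(self, idx):
        return self.samples[idx]
    def __len__(self):
        return len(self.samples)

loader_config = {"shuffle": True, "batch_size_per_card": 4,
                 "drop_last": False, "num_workers": 0}
dataset = ToyDataSet({}, "Train", seed=1024)
sampler = BatchSampler(dataset=dataset,
                       batch_size=loader_config["batch_size_per_card"],
                       shuffle=loader_config["shuffle"],  # now read from the config
                       drop_last=loader_config["drop_last"])
loader = DataLoader(dataset, batch_sampler=sampler,
                    num_workers=loader_config["num_workers"])
```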
diff --git a/ppocr/data/imaug/__init__.py b/ppocr/data/imaug/__init__.py
index 6ea4dd8e..250ac75e 100644
--- a/ppocr/data/imaug/__init__.py
+++ b/ppocr/data/imaug/__init__.py
@@ -21,7 +21,7 @@ from .make_border_map import MakeBorderMap
from .make_shrink_map import MakeShrinkMap
from .random_crop_data import EastRandomCropData, PSERandomCrop
-from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg
+from .rec_img_aug import RecAug, RecResizeImg, ClsResizeImg, SRNRecResizeImg
from .randaugment import RandAugment
from .operators import *
from .label_ops import *
diff --git a/ppocr/data/imaug/label_ops.py b/ppocr/data/imaug/label_ops.py
index af3308a5..61c0c196 100644
--- a/ppocr/data/imaug/label_ops.py
+++ b/ppocr/data/imaug/label_ops.py
@@ -18,6 +18,7 @@ from __future__ import print_function
from __future__ import unicode_literals
import numpy as np
+import string
class ClsLabelEncode(object):
@@ -92,18 +93,28 @@ class BaseRecLabelEncode(object):
character_type='ch',
use_space_char=False):
support_character_type = [
- 'ch', 'en', 'en_sensitive', 'french', 'german', 'japan', 'korean'
+ 'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean',
+ 'EN', 'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs',
+ 'oc', 'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi',
+ 'mr', 'ne'
]
assert character_type in support_character_type, "Only {} are supported now but get {}".format(
support_character_type, character_type)
self.max_text_len = max_text_length
+ self.beg_str = "sos"
+ self.end_str = "eos"
if character_type == "en":
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
- elif character_type in ["ch", "french", "german", "japan", "korean"]:
+ elif character_type == "EN_symbol":
+ # same with ASTER setting (use 94 char).
+ self.character_str = string.printable[:-6]
+ dict_character = list(self.character_str)
+ elif character_type in support_character_type:
self.character_str = ""
- assert character_dict_path is not None, "character_dict_path should not be None when character_type is ch"
+ assert character_dict_path is not None, "character_dict_path should not be None when character_type is {}".format(
+ character_type)
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
@@ -112,11 +123,6 @@ class BaseRecLabelEncode(object):
if use_space_char:
self.character_str += " "
dict_character = list(self.character_str)
- elif character_type == "en_sensitive":
- # same with ASTER setting (use 94 char).
- import string
- self.character_str = string.printable[:-6]
- dict_character = list(self.character_str)
self.character_type = character_type
dict_character = self.add_special_char(dict_character)
self.dict = {}
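What the renamed `EN_symbol` character type expands to can be checked against the standard library: `string.printable[:-6]` drops the six trailing whitespace characters, leaving the 94 visible ASCII characters of the ASTER setting:

```python
import string

charset = string.printable[:-6]  # digits + letters + punctuation
print(len(charset))              # 94
print(charset[:13])              # 0123456789abc
```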
@@ -213,3 +219,49 @@ class AttnLabelEncode(BaseRecLabelEncode):
assert False, "Unsupport type %s in get_beg_end_flag_idx" \
% beg_or_end
return idx
+
+
+class SRNLabelEncode(BaseRecLabelEncode):
+ """ Convert between text-label and text-index """
+
+ def __init__(self,
+ max_text_length=25,
+ character_dict_path=None,
+ character_type='en',
+ use_space_char=False,
+ **kwargs):
+ super(SRNLabelEncode,
+ self).__init__(max_text_length, character_dict_path,
+ character_type, use_space_char)
+
+ def add_special_char(self, dict_character):
+ dict_character = dict_character + [self.beg_str, self.end_str]
+ return dict_character
+
+ def __call__(self, data):
+ text = data['label']
+ text = self.encode(text)
+ char_num = len(self.character_str)
+ if text is None:
+ return None
+ if len(text) > self.max_text_len:
+ return None
+ data['length'] = np.array(len(text))
+ text = text + [char_num] * (self.max_text_len - len(text))
+ data['label'] = np.array(text)
+ return data
+
+ def get_ignored_tokens(self):
+ beg_idx = self.get_beg_end_flag_idx("beg")
+ end_idx = self.get_beg_end_flag_idx("end")
+ return [beg_idx, end_idx]
+
+ def get_beg_end_flag_idx(self, beg_or_end):
+ if beg_or_end == "beg":
+ idx = np.array(self.dict[self.beg_str])
+ elif beg_or_end == "end":
+ idx = np.array(self.dict[self.end_str])
+ else:
+ assert False, "Unsupport type %s in get_beg_end_flag_idx" \
+ % beg_or_end
+ return idx
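To illustrate the padding convention in `SRNLabelEncode.__call__` above: labels shorter than `max_text_len` are right-padded with `char_num`, the index one past the last real character. A toy walk-through with a hypothetical 3-character dictionary:

```python
# Suppose character_str = "abc", so char_num = 3 and max_text_len = 5.
text_indices = [0, 2]  # the encoded label "ac"
char_num, max_text_len = 3, 5
padded = text_indices + [char_num] * (max_text_len - len(text_indices))
print(padded)  # [0, 2, 3, 3, 3]
```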
diff --git a/ppocr/data/imaug/rec_img_aug.py b/ppocr/data/imaug/rec_img_aug.py
index 2ccb2d1d..28e6bd0b 100644
--- a/ppocr/data/imaug/rec_img_aug.py
+++ b/ppocr/data/imaug/rec_img_aug.py
@@ -12,20 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
-# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
import math
import cv2
import numpy as np
@@ -77,6 +63,26 @@ class RecResizeImg(object):
return data
+class SRNRecResizeImg(object):
+ def __init__(self, image_shape, num_heads, max_text_length, **kwargs):
+ self.image_shape = image_shape
+ self.num_heads = num_heads
+ self.max_text_length = max_text_length
+
+ def __call__(self, data):
+ img = data['image']
+ norm_img = resize_norm_img_srn(img, self.image_shape)
+ data['image'] = norm_img
+ [encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
+ srn_other_inputs(self.image_shape, self.num_heads, self.max_text_length)
+
+ data['encoder_word_pos'] = encoder_word_pos
+ data['gsrm_word_pos'] = gsrm_word_pos
+ data['gsrm_slf_attn_bias1'] = gsrm_slf_attn_bias1
+ data['gsrm_slf_attn_bias2'] = gsrm_slf_attn_bias2
+ return data
+
+
def resize_norm_img(img, image_shape):
imgC, imgH, imgW = image_shape
h = img.shape[0]
@@ -103,7 +109,7 @@ def resize_norm_img(img, image_shape):
def resize_norm_img_chinese(img, image_shape):
imgC, imgH, imgW = image_shape
# todo: change to 0 and modified image shape
- max_wh_ratio = 0
+ max_wh_ratio = imgW * 1.0 / imgH
h, w = img.shape[0], img.shape[1]
ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, ratio)
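The one-line fix above changes how `resize_norm_img_chinese` sizes narrow crops: seeding `max_wh_ratio` with `imgW / imgH` instead of 0 keeps the resize width from collapsing below the configured width. An illustrative before/after:

```python
imgC, imgH, imgW = 3, 32, 320
h, w = 32, 48                                    # a narrow crop (ratio 1.5)
old_width = int(imgH * max(0, w / h))            # 48: varies with each image
new_width = int(imgH * max(imgW / imgH, w / h))  # 320: floored at the configured imgW
print(old_width, new_width)
```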
@@ -126,6 +132,60 @@ def resize_norm_img_chinese(img, image_shape):
return padding_im
+def resize_norm_img_srn(img, image_shape):
+ imgC, imgH, imgW = image_shape
+
+ img_black = np.zeros((imgH, imgW))
+ im_hei = img.shape[0]
+ im_wid = img.shape[1]
+
+ if im_wid <= im_hei * 1:
+ img_new = cv2.resize(img, (imgH * 1, imgH))
+ elif im_wid <= im_hei * 2:
+ img_new = cv2.resize(img, (imgH * 2, imgH))
+ elif im_wid <= im_hei * 3:
+ img_new = cv2.resize(img, (imgH * 3, imgH))
+ else:
+ img_new = cv2.resize(img, (imgW, imgH))
+
+ img_np = np.asarray(img_new)
+ img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
+ img_black[:, 0:img_np.shape[1]] = img_np
+ img_black = img_black[:, :, np.newaxis]
+
+ row, col, c = img_black.shape
+ c = 1
+
+ return np.reshape(img_black, (c, row, col)).astype(np.float32)
+
+
+def srn_other_inputs(image_shape, num_heads, max_text_length):
+
+ imgC, imgH, imgW = image_shape
+ feature_dim = int((imgH / 8) * (imgW / 8))
+
+ encoder_word_pos = np.array(range(0, feature_dim)).reshape(
+ (feature_dim, 1)).astype('int64')
+ gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
+ (max_text_length, 1)).astype('int64')
+
+ gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
+ gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
+ [1, max_text_length, max_text_length])
+ gsrm_slf_attn_bias1 = np.tile(gsrm_slf_attn_bias1,
+ [num_heads, 1, 1]) * [-1e9]
+
+ gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
+ [1, max_text_length, max_text_length])
+ gsrm_slf_attn_bias2 = np.tile(gsrm_slf_attn_bias2,
+ [num_heads, 1, 1]) * [-1e9]
+
+ return [
+ encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
+ gsrm_slf_attn_bias2
+ ]
+
+
def flag():
"""
flag
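A quick sanity check on the shapes produced by `srn_other_inputs` (assuming the default SRN setup: 1x64x256 input, 8 heads, max_text_length 25); this snippet only mirrors the arithmetic above:

    import numpy as np

    imgC, imgH, imgW = 1, 64, 256
    num_heads, max_text_length = 8, 25

    feature_dim = int((imgH / 8) * (imgW / 8))   # backbone downsamples 8x -> 256
    encoder_word_pos = np.arange(feature_dim).reshape((feature_dim, 1))
    gsrm_word_pos = np.arange(max_text_length).reshape((max_text_length, 1))

    bias = np.ones((1, max_text_length, max_text_length))
    # the upper triangle masks future positions; tiling per head and scaling
    # by -1e9 turns the 0/1 mask into a large negative attention bias
    slf_attn_bias1 = np.tile(np.triu(bias, 1), [num_heads, 1, 1]) * -1e9

    print(encoder_word_pos.shape, gsrm_word_pos.shape, slf_attn_bias1.shape)
    # (256, 1) (25, 1) (8, 25, 25)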
diff --git a/ppocr/data/lmdb_dataset.py b/ppocr/data/lmdb_dataset.py
index e7bb6dd3..bd0630f6 100644
--- a/ppocr/data/lmdb_dataset.py
+++ b/ppocr/data/lmdb_dataset.py
@@ -21,7 +21,7 @@ from .imaug import transform, create_operators
class LMDBDateSet(Dataset):
- def __init__(self, config, mode, logger):
+ def __init__(self, config, mode, logger, seed=None):
super(LMDBDateSet, self).__init__()
global_config = config['Global']
diff --git a/ppocr/data/simple_dataset.py b/ppocr/data/simple_dataset.py
index ab17dd1a..d2a86b0f 100644
--- a/ppocr/data/simple_dataset.py
+++ b/ppocr/data/simple_dataset.py
@@ -20,7 +20,7 @@ from .imaug import transform, create_operators
class SimpleDataSet(Dataset):
- def __init__(self, config, mode, logger):
+ def __init__(self, config, mode, logger, seed=None):
super(SimpleDataSet, self).__init__()
self.logger = logger
@@ -41,6 +41,7 @@ class SimpleDataSet(Dataset):
self.data_dir = dataset_config['data_dir']
self.do_shuffle = loader_config['shuffle']
+ self.seed = seed
logger.info("Initialize indexs of datasets:%s" % label_file_list)
self.data_lines = self.get_image_info_list(label_file_list, ratio_list)
self.data_idx_order_list = list(range(len(self.data_lines)))
@@ -55,6 +56,7 @@ class SimpleDataSet(Dataset):
for idx, file in enumerate(file_list):
with open(file, "rb") as f:
lines = f.readlines()
+ random.seed(self.seed)
lines = random.sample(lines,
round(len(lines) * ratio_list[idx]))
data_lines.extend(lines)
@@ -62,6 +64,7 @@ class SimpleDataSet(Dataset):
def shuffle_data_random(self):
if self.do_shuffle:
+ random.seed(self.seed)
random.shuffle(self.data_lines)
return
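Seeding `random` with a shared value right before sampling and shuffling is what keeps the selected subset and its order identical across distributed workers; since the trainer passes `seed=epoch` (see tools/program.py below), the order still changes between epochs. A small demonstration of the invariant, under that assumption:

    import random

    lines = list(range(10))
    orders = []
    for worker in range(2):          # two workers sharing the same seed
        random.seed(42)
        orders.append(random.sample(lines, len(lines)))
    print(orders[0] == orders[1])    # True: both workers agree on the order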
diff --git a/ppocr/losses/__init__.py b/ppocr/losses/__init__.py
index 4673d35c..b280eb33 100755
--- a/ppocr/losses/__init__.py
+++ b/ppocr/losses/__init__.py
@@ -23,11 +23,14 @@ def build_loss(config):
# rec loss
from .rec_ctc_loss import CTCLoss
+ from .rec_srn_loss import SRNLoss
# cls loss
from .cls_loss import ClsLoss
- support_dict = ['DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss']
+ support_dict = [
+ 'DBLoss', 'EASTLoss', 'SASTLoss', 'CTCLoss', 'ClsLoss', 'SRNLoss'
+ ]
config = copy.deepcopy(config)
module_name = config.pop('name')
diff --git a/ppocr/losses/rec_srn_loss.py b/ppocr/losses/rec_srn_loss.py
new file mode 100644
index 00000000..7d5b65eb
--- /dev/null
+++ b/ppocr/losses/rec_srn_loss.py
@@ -0,0 +1,47 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import paddle
+from paddle import nn
+
+
+class SRNLoss(nn.Layer):
+ def __init__(self, **kwargs):
+ super(SRNLoss, self).__init__()
+ self.loss_func = paddle.nn.loss.CrossEntropyLoss(reduction="sum")
+
+ def forward(self, predicts, batch):
+ predict = predicts['predict']
+ word_predict = predicts['word_out']
+ gsrm_predict = predicts['gsrm_out']
+ label = batch[1]
+
+ casted_label = paddle.cast(x=label, dtype='int64')
+ casted_label = paddle.reshape(x=casted_label, shape=[-1, 1])
+
+ cost_word = self.loss_func(word_predict, label=casted_label)
+ cost_gsrm = self.loss_func(gsrm_predict, label=casted_label)
+ cost_vsfd = self.loss_func(predict, label=casted_label)
+
+ cost_word = paddle.reshape(x=paddle.sum(cost_word), shape=[1])
+ cost_gsrm = paddle.reshape(x=paddle.sum(cost_gsrm), shape=[1])
+ cost_vsfd = paddle.reshape(x=paddle.sum(cost_vsfd), shape=[1])
+
+ sum_cost = cost_word * 3.0 + cost_vsfd + cost_gsrm * 0.15
+
+ return {'loss': sum_cost, 'word_loss': cost_word, 'img_loss': cost_vsfd}
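The combined objective weights the word-level branch most heavily. A worked example of the arithmetic, with illustrative loss values:

    # weights from SRNLoss.forward above
    cost_word, cost_gsrm, cost_vsfd = 2.0, 4.0, 1.0   # example summed CE losses
    sum_cost = cost_word * 3.0 + cost_vsfd + cost_gsrm * 0.15
    print(sum_cost)   # 6.0 + 1.0 + 0.6 = 7.6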
diff --git a/ppocr/metrics/rec_metric.py b/ppocr/metrics/rec_metric.py
index a86fc838..b3aa9f38 100644
--- a/ppocr/metrics/rec_metric.py
+++ b/ppocr/metrics/rec_metric.py
@@ -33,8 +33,6 @@ class RecMetric(object):
if pred == target:
correct_num += 1
all_num += 1
- # if all_num < 10 and kwargs.get('show_str', False):
- # print('{} -> {}'.format(pred, target))
self.correct_num += correct_num
self.all_num += all_num
self.norm_edit_dis += norm_edit_dis
@@ -50,7 +48,7 @@ class RecMetric(object):
'norm_edit_dis': 0,
}
"""
- acc = self.correct_num / self.all_num
+ acc = 1.0 * self.correct_num / self.all_num
norm_edit_dis = 1 - self.norm_edit_dis / self.all_num
self.reset()
return {'acc': acc, 'norm_edit_dis': norm_edit_dis}
diff --git a/ppocr/modeling/architectures/base_model.py b/ppocr/modeling/architectures/base_model.py
index ab44b53a..09b6e034 100644
--- a/ppocr/modeling/architectures/base_model.py
+++ b/ppocr/modeling/architectures/base_model.py
@@ -68,11 +68,14 @@ class BaseModel(nn.Layer):
config["Head"]['in_channels'] = in_channels
self.head = build_head(config["Head"])
- def forward(self, x):
+ def forward(self, x, data=None):
if self.use_transform:
x = self.transform(x)
x = self.backbone(x)
if self.use_neck:
x = self.neck(x)
- x = self.head(x)
+ if data is None:
+ x = self.head(x)
+ else:
+ x = self.head(x, data)
return x
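The optional `data` argument lets heads that need auxiliary tensors (SRN's word positions and attention biases) receive them, while CTC-style heads keep the one-argument call. A toy sketch of the dispatch pattern (stand-in classes, not the repo's):

    class CTCLikeHead:
        def __call__(self, x):
            return x

    class SRNLikeHead:
        def __call__(self, x, data):
            return (x, data)

    def forward(head, x, data=None):
        # mirrors BaseModel.forward: pass `data` through only when given
        return head(x) if data is None else head(x, data)

    print(forward(CTCLikeHead(), 'feats'))
    print(forward(SRNLikeHead(), 'feats', data=['aux']))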
diff --git a/ppocr/modeling/backbones/__init__.py b/ppocr/modeling/backbones/__init__.py
index 43103e53..03c15508 100755
--- a/ppocr/modeling/backbones/__init__.py
+++ b/ppocr/modeling/backbones/__init__.py
@@ -24,7 +24,8 @@ def build_backbone(config, model_type):
elif model_type == 'rec' or model_type == 'cls':
from .rec_mobilenet_v3 import MobileNetV3
from .rec_resnet_vd import ResNet
- support_dict = ['MobileNetV3', 'ResNet', 'ResNet_FPN']
+ from .rec_resnet_fpn import ResNetFPN
+ support_dict = ['MobileNetV3', 'ResNet', 'ResNetFPN']
else:
raise NotImplementedError
diff --git a/ppocr/modeling/backbones/det_mobilenet_v3.py b/ppocr/modeling/backbones/det_mobilenet_v3.py
index f97bcfca..bb451bbe 100755
--- a/ppocr/modeling/backbones/det_mobilenet_v3.py
+++ b/ppocr/modeling/backbones/det_mobilenet_v3.py
@@ -58,15 +58,15 @@ class MobileNetV3(nn.Layer):
[5, 72, 40, True, 'relu', 2],
[5, 120, 40, True, 'relu', 1],
[5, 120, 40, True, 'relu', 1],
- [3, 240, 80, False, 'hard_swish', 2],
- [3, 200, 80, False, 'hard_swish', 1],
- [3, 184, 80, False, 'hard_swish', 1],
- [3, 184, 80, False, 'hard_swish', 1],
- [3, 480, 112, True, 'hard_swish', 1],
- [3, 672, 112, True, 'hard_swish', 1],
- [5, 672, 160, True, 'hard_swish', 2],
- [5, 960, 160, True, 'hard_swish', 1],
- [5, 960, 160, True, 'hard_swish', 1],
+ [3, 240, 80, False, 'hardswish', 2],
+ [3, 200, 80, False, 'hardswish', 1],
+ [3, 184, 80, False, 'hardswish', 1],
+ [3, 184, 80, False, 'hardswish', 1],
+ [3, 480, 112, True, 'hardswish', 1],
+ [3, 672, 112, True, 'hardswish', 1],
+ [5, 672, 160, True, 'hardswish', 2],
+ [5, 960, 160, True, 'hardswish', 1],
+ [5, 960, 160, True, 'hardswish', 1],
]
cls_ch_squeeze = 960
elif model_name == "small":
@@ -75,14 +75,14 @@ class MobileNetV3(nn.Layer):
[3, 16, 16, True, 'relu', 2],
[3, 72, 24, False, 'relu', 2],
[3, 88, 24, False, 'relu', 1],
- [5, 96, 40, True, 'hard_swish', 2],
- [5, 240, 40, True, 'hard_swish', 1],
- [5, 240, 40, True, 'hard_swish', 1],
- [5, 120, 48, True, 'hard_swish', 1],
- [5, 144, 48, True, 'hard_swish', 1],
- [5, 288, 96, True, 'hard_swish', 2],
- [5, 576, 96, True, 'hard_swish', 1],
- [5, 576, 96, True, 'hard_swish', 1],
+ [5, 96, 40, True, 'hardswish', 2],
+ [5, 240, 40, True, 'hardswish', 1],
+ [5, 240, 40, True, 'hardswish', 1],
+ [5, 120, 48, True, 'hardswish', 1],
+ [5, 144, 48, True, 'hardswish', 1],
+ [5, 288, 96, True, 'hardswish', 2],
+ [5, 576, 96, True, 'hardswish', 1],
+ [5, 576, 96, True, 'hardswish', 1],
]
cls_ch_squeeze = 576
else:
@@ -102,7 +102,7 @@ class MobileNetV3(nn.Layer):
padding=1,
groups=1,
if_act=True,
- act='hard_swish',
+ act='hardswish',
name='conv1')
self.stages = []
@@ -112,7 +112,8 @@ class MobileNetV3(nn.Layer):
inplanes = make_divisible(inplanes * scale)
for (k, exp, c, se, nl, s) in cfg:
se = se and not self.disable_se
- if s == 2 and i > 2:
+ start_idx = 2 if model_name == 'large' else 0
+ if s == 2 and i > start_idx:
self.out_channels.append(inplanes)
self.stages.append(nn.Sequential(*block_list))
block_list = []
@@ -137,7 +138,7 @@ class MobileNetV3(nn.Layer):
padding=0,
groups=1,
if_act=True,
- act='hard_swish',
+ act='hardswish',
name='conv_last'))
self.stages.append(nn.Sequential(*block_list))
self.out_channels.append(make_divisible(scale * cls_ch_squeeze))
@@ -191,10 +192,11 @@ class ConvBNLayer(nn.Layer):
if self.if_act:
if self.act == "relu":
x = F.relu(x)
- elif self.act == "hard_swish":
- x = F.activation.hard_swish(x)
+ elif self.act == "hardswish":
+ x = F.hardswish(x)
else:
- print("The activation function is selected incorrectly.")
+ print("The activation function({}) is selected incorrectly.".
+ format(self.act))
exit()
return x
@@ -281,5 +283,5 @@ class SEModule(nn.Layer):
outputs = self.conv1(outputs)
outputs = F.relu(outputs)
outputs = self.conv2(outputs)
- outputs = F.activation.hard_sigmoid(outputs)
+ outputs = F.hardsigmoid(outputs, slope=0.2, offset=0.5)
return inputs * outputs
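For reference, the renamed activations compute the following (a NumPy sketch of the standard definitions; `slope=0.2, offset=0.5` matches the call above):

    import numpy as np

    def hardsigmoid(x, slope=0.2, offset=0.5):
        return np.clip(slope * x + offset, 0.0, 1.0)

    def hardswish(x):
        # x * relu6(x + 3) / 6
        return x * np.clip(x + 3.0, 0.0, 6.0) / 6.0

    x = np.array([-4.0, -1.0, 0.0, 2.0, 5.0])
    print(hardsigmoid(x))   # [0.  0.3 0.5 0.9 1. ]
    print(hardswish(x))     # approx [-0., -0.333, 0., 1.667, 5.]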
diff --git a/ppocr/modeling/backbones/rec_mobilenet_v3.py b/ppocr/modeling/backbones/rec_mobilenet_v3.py
index bdf4a616..1ff17159 100644
--- a/ppocr/modeling/backbones/rec_mobilenet_v3.py
+++ b/ppocr/modeling/backbones/rec_mobilenet_v3.py
@@ -51,15 +51,15 @@ class MobileNetV3(nn.Layer):
[5, 72, 40, True, 'relu', (large_stride[2], 1)],
[5, 120, 40, True, 'relu', 1],
[5, 120, 40, True, 'relu', 1],
- [3, 240, 80, False, 'hard_swish', 1],
- [3, 200, 80, False, 'hard_swish', 1],
- [3, 184, 80, False, 'hard_swish', 1],
- [3, 184, 80, False, 'hard_swish', 1],
- [3, 480, 112, True, 'hard_swish', 1],
- [3, 672, 112, True, 'hard_swish', 1],
- [5, 672, 160, True, 'hard_swish', (large_stride[3], 1)],
- [5, 960, 160, True, 'hard_swish', 1],
- [5, 960, 160, True, 'hard_swish', 1],
+ [3, 240, 80, False, 'hardswish', 1],
+ [3, 200, 80, False, 'hardswish', 1],
+ [3, 184, 80, False, 'hardswish', 1],
+ [3, 184, 80, False, 'hardswish', 1],
+ [3, 480, 112, True, 'hardswish', 1],
+ [3, 672, 112, True, 'hardswish', 1],
+ [5, 672, 160, True, 'hardswish', (large_stride[3], 1)],
+ [5, 960, 160, True, 'hardswish', 1],
+ [5, 960, 160, True, 'hardswish', 1],
]
cls_ch_squeeze = 960
elif model_name == "small":
@@ -68,14 +68,14 @@ class MobileNetV3(nn.Layer):
[3, 16, 16, True, 'relu', (small_stride[0], 1)],
[3, 72, 24, False, 'relu', (small_stride[1], 1)],
[3, 88, 24, False, 'relu', 1],
- [5, 96, 40, True, 'hard_swish', (small_stride[2], 1)],
- [5, 240, 40, True, 'hard_swish', 1],
- [5, 240, 40, True, 'hard_swish', 1],
- [5, 120, 48, True, 'hard_swish', 1],
- [5, 144, 48, True, 'hard_swish', 1],
- [5, 288, 96, True, 'hard_swish', (small_stride[3], 1)],
- [5, 576, 96, True, 'hard_swish', 1],
- [5, 576, 96, True, 'hard_swish', 1],
+ [5, 96, 40, True, 'hardswish', (small_stride[2], 1)],
+ [5, 240, 40, True, 'hardswish', 1],
+ [5, 240, 40, True, 'hardswish', 1],
+ [5, 120, 48, True, 'hardswish', 1],
+ [5, 144, 48, True, 'hardswish', 1],
+ [5, 288, 96, True, 'hardswish', (small_stride[3], 1)],
+ [5, 576, 96, True, 'hardswish', 1],
+ [5, 576, 96, True, 'hardswish', 1],
]
cls_ch_squeeze = 576
else:
@@ -96,7 +96,7 @@ class MobileNetV3(nn.Layer):
padding=1,
groups=1,
if_act=True,
- act='hard_swish',
+ act='hardswish',
name='conv1')
i = 0
block_list = []
@@ -124,7 +124,7 @@ class MobileNetV3(nn.Layer):
padding=0,
groups=1,
if_act=True,
- act='hard_swish',
+ act='hardswish',
name='conv_last')
self.pool = nn.MaxPool2D(kernel_size=2, stride=2, padding=0)
diff --git a/ppocr/modeling/backbones/rec_resnet_fpn.py b/ppocr/modeling/backbones/rec_resnet_fpn.py
new file mode 100644
index 00000000..a7e876a2
--- /dev/null
+++ b/ppocr/modeling/backbones/rec_resnet_fpn.py
@@ -0,0 +1,307 @@
+#copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+#Licensed under the Apache License, Version 2.0 (the "License");
+#you may not use this file except in compliance with the License.
+#You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+#Unless required by applicable law or agreed to in writing, software
+#distributed under the License is distributed on an "AS IS" BASIS,
+#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#See the License for the specific language governing permissions and
+#limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from paddle import nn, ParamAttr
+from paddle.nn import functional as F
+import paddle.fluid as fluid
+import paddle
+import numpy as np
+
+__all__ = ["ResNetFPN"]
+
+
+class ResNetFPN(nn.Layer):
+ def __init__(self, in_channels=1, layers=50, **kwargs):
+ super(ResNetFPN, self).__init__()
+ supported_layers = {
+ 18: {
+ 'depth': [2, 2, 2, 2],
+ 'block_class': BasicBlock
+ },
+ 34: {
+ 'depth': [3, 4, 6, 3],
+ 'block_class': BasicBlock
+ },
+ 50: {
+ 'depth': [3, 4, 6, 3],
+ 'block_class': BottleneckBlock
+ },
+ 101: {
+ 'depth': [3, 4, 23, 3],
+ 'block_class': BottleneckBlock
+ },
+ 152: {
+ 'depth': [3, 8, 36, 3],
+ 'block_class': BottleneckBlock
+ }
+ }
+ stride_list = [(2, 2), (2, 2), (1, 1), (1, 1)]
+ num_filters = [64, 128, 256, 512]
+ self.depth = supported_layers[layers]['depth']
+ self.F = []
+ self.conv = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=64,
+ kernel_size=7,
+ stride=2,
+ act="relu",
+ name="conv1")
+ self.block_list = []
+ in_ch = 64
+ if layers >= 50:
+ for block in range(len(self.depth)):
+ for i in range(self.depth[block]):
+ if layers in [101, 152] and block == 2:
+ if i == 0:
+ conv_name = "res" + str(block + 2) + "a"
+ else:
+ conv_name = "res" + str(block + 2) + "b" + str(i)
+ else:
+ conv_name = "res" + str(block + 2) + chr(97 + i)
+ block_list = self.add_sublayer(
+ "bottleneckBlock_{}_{}".format(block, i),
+ BottleneckBlock(
+ in_channels=in_ch,
+ out_channels=num_filters[block],
+ stride=stride_list[block] if i == 0 else 1,
+ name=conv_name))
+ in_ch = num_filters[block] * 4
+ self.block_list.append(block_list)
+ self.F.append(block_list)
+ else:
+ for block in range(len(self.depth)):
+ for i in range(self.depth[block]):
+ conv_name = "res" + str(block + 2) + chr(97 + i)
+ if i == 0 and block != 0:
+ stride = (2, 1)
+ else:
+ stride = (1, 1)
+ basic_block = self.add_sublayer(
+ conv_name,
+ BasicBlock(
+ in_channels=in_ch,
+ out_channels=num_filters[block],
+ stride=stride_list[block] if i == 0 else 1,
+ is_first=block == i == 0,
+ name=conv_name))
+ in_ch = basic_block.out_channels
+ self.block_list.append(basic_block)
+ out_ch_list = [in_ch // 4, in_ch // 2, in_ch]
+ self.base_block = []
+ self.conv_trans = []
+ self.bn_block = []
+ for i in [-2, -3]:
+ in_channels = out_ch_list[i + 1] + out_ch_list[i]
+
+ self.base_block.append(
+ self.add_sublayer(
+ "F_{}_base_block_0".format(i),
+ nn.Conv2D(
+ in_channels=in_channels,
+ out_channels=out_ch_list[i],
+ kernel_size=1,
+ weight_attr=ParamAttr(trainable=True),
+ bias_attr=ParamAttr(trainable=True))))
+ self.base_block.append(
+ self.add_sublayer(
+ "F_{}_base_block_1".format(i),
+ nn.Conv2D(
+ in_channels=out_ch_list[i],
+ out_channels=out_ch_list[i],
+ kernel_size=3,
+ padding=1,
+ weight_attr=ParamAttr(trainable=True),
+ bias_attr=ParamAttr(trainable=True))))
+ self.base_block.append(
+ self.add_sublayer(
+ "F_{}_base_block_2".format(i),
+ nn.BatchNorm(
+ num_channels=out_ch_list[i],
+ act="relu",
+ param_attr=ParamAttr(trainable=True),
+ bias_attr=ParamAttr(trainable=True))))
+ self.base_block.append(
+ self.add_sublayer(
+ "F_{}_base_block_3".format(i),
+ nn.Conv2D(
+ in_channels=out_ch_list[i],
+ out_channels=512,
+ kernel_size=1,
+ bias_attr=ParamAttr(trainable=True),
+ weight_attr=ParamAttr(trainable=True))))
+ self.out_channels = 512
+
+ def __call__(self, x):
+ x = self.conv(x)
+ fpn_list = []
+ F = []
+ for i in range(len(self.depth)):
+ fpn_list.append(np.sum(self.depth[:i + 1]))
+
+ for i, block in enumerate(self.block_list):
+ x = block(x)
+ for number in fpn_list:
+ if i + 1 == number:
+ F.append(x)
+ base = F[-1]
+
+ j = 0
+ for i, block in enumerate(self.base_block):
+ if i % 3 == 0 and i < 6:
+ j = j + 1
+ b, c, w, h = F[-j - 1].shape
+ if [w, h] != list(base.shape[2:]):
+ base = self.conv_trans[j - 1](base)
+ base = self.bn_block[j - 1](base)
+ base = paddle.concat([base, F[-j - 1]], axis=1)
+ base = block(base)
+ return base
+
+
+class ConvBNLayer(nn.Layer):
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ groups=1,
+ act=None,
+ name=None):
+ super(ConvBNLayer, self).__init__()
+ self.conv = nn.Conv2D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=2 if stride == (1, 1) else kernel_size,
+ dilation=2 if stride == (1, 1) else 1,
+ stride=stride,
+ padding=(kernel_size - 1) // 2,
+ groups=groups,
+ weight_attr=ParamAttr(name=name + '.conv2d.output.1.w_0'),
+ bias_attr=False, )
+
+ if name == "conv1":
+ bn_name = "bn_" + name
+ else:
+ bn_name = "bn" + name[3:]
+ self.bn = nn.BatchNorm(
+ num_channels=out_channels,
+ act=act,
+ param_attr=ParamAttr(name=name + '.output.1.w_0'),
+ bias_attr=ParamAttr(name=name + '.output.1.b_0'),
+ moving_mean_name=bn_name + "_mean",
+ moving_variance_name=bn_name + "_variance")
+
+ def __call__(self, x):
+ x = self.conv(x)
+ x = self.bn(x)
+ return x
+
+
+class ShortCut(nn.Layer):
+ def __init__(self, in_channels, out_channels, stride, name, is_first=False):
+ super(ShortCut, self).__init__()
+ self.use_conv = True
+
+ if in_channels != out_channels or stride != 1 or is_first == True:
+ if stride == (1, 1):
+ self.conv = ConvBNLayer(
+ in_channels, out_channels, 1, 1, name=name)
+ else: # stride==(2,2)
+ self.conv = ConvBNLayer(
+ in_channels, out_channels, 1, stride, name=name)
+ else:
+ self.use_conv = False
+
+ def forward(self, x):
+ if self.use_conv:
+ x = self.conv(x)
+ return x
+
+
+class BottleneckBlock(nn.Layer):
+ def __init__(self, in_channels, out_channels, stride, name):
+ super(BottleneckBlock, self).__init__()
+ self.conv0 = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=1,
+ act='relu',
+ name=name + "_branch2a")
+ self.conv1 = ConvBNLayer(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ stride=stride,
+ act='relu',
+ name=name + "_branch2b")
+
+ self.conv2 = ConvBNLayer(
+ in_channels=out_channels,
+ out_channels=out_channels * 4,
+ kernel_size=1,
+ act=None,
+ name=name + "_branch2c")
+
+ self.short = ShortCut(
+ in_channels=in_channels,
+ out_channels=out_channels * 4,
+ stride=stride,
+ is_first=False,
+ name=name + "_branch1")
+ self.out_channels = out_channels * 4
+
+ def forward(self, x):
+ y = self.conv0(x)
+ y = self.conv1(y)
+ y = self.conv2(y)
+ y = y + self.short(x)
+ y = F.relu(y)
+ return y
+
+
+class BasicBlock(nn.Layer):
+ def __init__(self, in_channels, out_channels, stride, name, is_first):
+ super(BasicBlock, self).__init__()
+ self.conv0 = ConvBNLayer(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ act='relu',
+ stride=stride,
+ name=name + "_branch2a")
+ self.conv1 = ConvBNLayer(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ kernel_size=3,
+ act=None,
+ name=name + "_branch2b")
+ self.short = ShortCut(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ stride=stride,
+ is_first=is_first,
+ name=name + "_branch1")
+ self.out_channels = out_channels
+
+ def forward(self, x):
+ y = self.conv0(x)
+ y = self.conv1(y)
+ y = y + self.short(x)
+ return F.relu(y)
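With the stride schedule above (conv1 stride 2, then stage strides (2,2), (2,2), (1,1), (1,1)), the backbone downsamples the default 64x256 SRN input by 8x in both directions, which is exactly the `feature_dim = (imgH/8)*(imgW/8)` assumed by the SRN head. A shape walk:

    h, w = 64, 256
    for sh, sw in [(2, 2), (2, 2), (2, 2), (1, 1), (1, 1)]:  # conv1 + 4 stages
        h, w = h // sh, w // sw
    print(h, w, h * w)   # 8 32 256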
diff --git a/ppocr/modeling/heads/__init__.py b/ppocr/modeling/heads/__init__.py
index 78074709..1a39ca41 100755
--- a/ppocr/modeling/heads/__init__.py
+++ b/ppocr/modeling/heads/__init__.py
@@ -23,10 +23,13 @@ def build_head(config):
# rec head
from .rec_ctc_head import CTCHead
+ from .rec_srn_head import SRNHead
# cls head
from .cls_head import ClsHead
- support_dict = ['DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead']
+ support_dict = [
+ 'DBHead', 'EASTHead', 'SASTHead', 'CTCHead', 'ClsHead', 'SRNHead'
+ ]
module_name = config.pop('name')
assert module_name in support_dict, Exception('head only support {}'.format(
diff --git a/ppocr/modeling/heads/rec_srn_head.py b/ppocr/modeling/heads/rec_srn_head.py
new file mode 100644
index 00000000..8aaf65e1
--- /dev/null
+++ b/ppocr/modeling/heads/rec_srn_head.py
@@ -0,0 +1,279 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+import paddle
+from paddle import nn, ParamAttr
+from paddle.nn import functional as F
+import paddle.fluid as fluid
+import numpy as np
+from .self_attention import WrapEncoderForFeature
+from .self_attention import WrapEncoder
+from ppocr.modeling.backbones.rec_resnet_fpn import ResNetFPN
+
+from collections import OrderedDict
+gradient_clip = 10
+
+
+class PVAM(nn.Layer):
+ def __init__(self, in_channels, char_num, max_text_length, num_heads,
+ num_encoder_tus, hidden_dims):
+ super(PVAM, self).__init__()
+ self.char_num = char_num
+ self.max_length = max_text_length
+ self.num_heads = num_heads
+ self.num_encoder_TUs = num_encoder_tus
+ self.hidden_dims = hidden_dims
+ # Transformer encoder
+ t = 256
+ c = 512
+ self.wrap_encoder_for_feature = WrapEncoderForFeature(
+ src_vocab_size=1,
+ max_length=t,
+ n_layer=self.num_encoder_TUs,
+ n_head=self.num_heads,
+ d_key=int(self.hidden_dims / self.num_heads),
+ d_value=int(self.hidden_dims / self.num_heads),
+ d_model=self.hidden_dims,
+ d_inner_hid=self.hidden_dims,
+ prepostprocess_dropout=0.1,
+ attention_dropout=0.1,
+ relu_dropout=0.1,
+ preprocess_cmd="n",
+ postprocess_cmd="da",
+ weight_sharing=True)
+
+ # PVAM
+ self.flatten0 = paddle.nn.Flatten(start_axis=0, stop_axis=1)
+ self.fc0 = paddle.nn.Linear(
+ in_features=in_channels,
+ out_features=in_channels, )
+ self.emb = paddle.nn.Embedding(
+ num_embeddings=self.max_length, embedding_dim=in_channels)
+ self.flatten1 = paddle.nn.Flatten(start_axis=0, stop_axis=2)
+ self.fc1 = paddle.nn.Linear(
+ in_features=in_channels, out_features=1, bias_attr=False)
+
+ def forward(self, inputs, encoder_word_pos, gsrm_word_pos):
+ b, c, h, w = inputs.shape
+ conv_features = paddle.reshape(inputs, shape=[-1, c, h * w])
+ conv_features = paddle.transpose(conv_features, perm=[0, 2, 1])
+ # transformer encoder
+ b, t, c = conv_features.shape
+
+ enc_inputs = [conv_features, encoder_word_pos, None]
+ word_features = self.wrap_encoder_for_feature(enc_inputs)
+
+ # pvam
+ b, t, c = word_features.shape
+ word_features = self.fc0(word_features)
+ word_features_ = paddle.reshape(word_features, [-1, 1, t, c])
+ word_features_ = paddle.tile(word_features_, [1, self.max_length, 1, 1])
+ word_pos_feature = self.emb(gsrm_word_pos)
+ word_pos_feature_ = paddle.reshape(word_pos_feature,
+ [-1, self.max_length, 1, c])
+ word_pos_feature_ = paddle.tile(word_pos_feature_, [1, 1, t, 1])
+ y = word_pos_feature_ + word_features_
+ y = F.tanh(y)
+ attention_weight = self.fc1(y)
+ attention_weight = paddle.reshape(
+ attention_weight, shape=[-1, self.max_length, t])
+ attention_weight = F.softmax(attention_weight, axis=-1)
+ pvam_features = paddle.matmul(attention_weight,
+ word_features) #[b, max_length, c]
+ return pvam_features
+
+
+class GSRM(nn.Layer):
+ def __init__(self, in_channels, char_num, max_text_length, num_heads,
+ num_encoder_tus, num_decoder_tus, hidden_dims):
+ super(GSRM, self).__init__()
+ self.char_num = char_num
+ self.max_length = max_text_length
+ self.num_heads = num_heads
+ self.num_encoder_TUs = num_encoder_tus
+ self.num_decoder_TUs = num_decoder_tus
+ self.hidden_dims = hidden_dims
+
+ self.fc0 = paddle.nn.Linear(
+ in_features=in_channels, out_features=self.char_num)
+ self.wrap_encoder0 = WrapEncoder(
+ src_vocab_size=self.char_num + 1,
+ max_length=self.max_length,
+ n_layer=self.num_decoder_TUs,
+ n_head=self.num_heads,
+ d_key=int(self.hidden_dims / self.num_heads),
+ d_value=int(self.hidden_dims / self.num_heads),
+ d_model=self.hidden_dims,
+ d_inner_hid=self.hidden_dims,
+ prepostprocess_dropout=0.1,
+ attention_dropout=0.1,
+ relu_dropout=0.1,
+ preprocess_cmd="n",
+ postprocess_cmd="da",
+ weight_sharing=True)
+
+ self.wrap_encoder1 = WrapEncoder(
+ src_vocab_size=self.char_num + 1,
+ max_length=self.max_length,
+ n_layer=self.num_decoder_TUs,
+ n_head=self.num_heads,
+ d_key=int(self.hidden_dims / self.num_heads),
+ d_value=int(self.hidden_dims / self.num_heads),
+ d_model=self.hidden_dims,
+ d_inner_hid=self.hidden_dims,
+ prepostprocess_dropout=0.1,
+ attention_dropout=0.1,
+ relu_dropout=0.1,
+ preprocess_cmd="n",
+ postprocess_cmd="da",
+ weight_sharing=True)
+
+ self.mul = lambda x: paddle.matmul(x=x,
+ y=self.wrap_encoder0.prepare_decoder.emb0.weight,
+ transpose_y=True)
+
+ def forward(self, inputs, gsrm_word_pos, gsrm_slf_attn_bias1,
+ gsrm_slf_attn_bias2):
+ # ===== GSRM Visual-to-semantic embedding block =====
+ b, t, c = inputs.shape
+ pvam_features = paddle.reshape(inputs, [-1, c])
+ word_out = self.fc0(pvam_features)
+ word_ids = paddle.argmax(F.softmax(word_out), axis=1)
+ word_ids = paddle.reshape(x=word_ids, shape=[-1, t, 1])
+
+ # ===== GSRM Semantic reasoning block =====
+ """
+ This module is implemented with bi-transformers:
+ ngram_feature1 is the forward one, ngram_feature2 is the backward one
+ """
+ pad_idx = self.char_num
+
+ word1 = paddle.cast(word_ids, "float32")
+ word1 = F.pad(word1, [1, 0], value=1.0 * pad_idx, data_format="NLC")
+ word1 = paddle.cast(word1, "int64")
+ word1 = word1[:, :-1, :]
+ word2 = word_ids
+
+ enc_inputs_1 = [word1, gsrm_word_pos, gsrm_slf_attn_bias1]
+ enc_inputs_2 = [word2, gsrm_word_pos, gsrm_slf_attn_bias2]
+
+ gsrm_feature1 = self.wrap_encoder0(enc_inputs_1)
+ gsrm_feature2 = self.wrap_encoder1(enc_inputs_2)
+
+ gsrm_feature2 = F.pad(gsrm_feature2, [0, 1],
+ value=0.,
+ data_format="NLC")
+ gsrm_feature2 = gsrm_feature2[:, 1:, ]
+ gsrm_features = gsrm_feature1 + gsrm_feature2
+
+ gsrm_out = self.mul(gsrm_features)
+
+ b, t, c = gsrm_out.shape
+ gsrm_out = paddle.reshape(gsrm_out, [-1, c])
+
+ return gsrm_features, word_out, gsrm_out
+
+
+class VSFD(nn.Layer):
+ def __init__(self, in_channels=512, pvam_ch=512, char_num=38):
+ super(VSFD, self).__init__()
+ self.char_num = char_num
+ self.fc0 = paddle.nn.Linear(
+ in_features=in_channels * 2, out_features=pvam_ch)
+ self.fc1 = paddle.nn.Linear(
+ in_features=pvam_ch, out_features=self.char_num)
+
+ def forward(self, pvam_feature, gsrm_feature):
+ b, t, c1 = pvam_feature.shape
+ b, t, c2 = gsrm_feature.shape
+ combine_feature_ = paddle.concat([pvam_feature, gsrm_feature], axis=2)
+ img_comb_feature_ = paddle.reshape(
+ combine_feature_, shape=[-1, c1 + c2])
+ img_comb_feature_map = self.fc0(img_comb_feature_)
+ img_comb_feature_map = F.sigmoid(img_comb_feature_map)
+ img_comb_feature_map = paddle.reshape(
+ img_comb_feature_map, shape=[-1, t, c1])
+ combine_feature = img_comb_feature_map * pvam_feature + (
+ 1.0 - img_comb_feature_map) * gsrm_feature
+ img_comb_feature = paddle.reshape(combine_feature, shape=[-1, c1])
+
+ out = self.fc1(img_comb_feature)
+ return out
+
+
+class SRNHead(nn.Layer):
+ def __init__(self, in_channels, out_channels, max_text_length, num_heads,
+ num_encoder_TUs, num_decoder_TUs, hidden_dims, **kwargs):
+ super(SRNHead, self).__init__()
+ self.char_num = out_channels
+ self.max_length = max_text_length
+ self.num_heads = num_heads
+ self.num_encoder_TUs = num_encoder_TUs
+ self.num_decoder_TUs = num_decoder_TUs
+ self.hidden_dims = hidden_dims
+
+ self.pvam = PVAM(
+ in_channels=in_channels,
+ char_num=self.char_num,
+ max_text_length=self.max_length,
+ num_heads=self.num_heads,
+ num_encoder_tus=self.num_encoder_TUs,
+ hidden_dims=self.hidden_dims)
+
+ self.gsrm = GSRM(
+ in_channels=in_channels,
+ char_num=self.char_num,
+ max_text_length=self.max_length,
+ num_heads=self.num_heads,
+ num_encoder_tus=self.num_encoder_TUs,
+ num_decoder_tus=self.num_decoder_TUs,
+ hidden_dims=self.hidden_dims)
+ self.vsfd = VSFD(in_channels=in_channels)
+
+ self.gsrm.wrap_encoder1.prepare_decoder.emb0 = self.gsrm.wrap_encoder0.prepare_decoder.emb0
+
+ def forward(self, inputs, others):
+ encoder_word_pos = others[0]
+ gsrm_word_pos = others[1]
+ gsrm_slf_attn_bias1 = others[2]
+ gsrm_slf_attn_bias2 = others[3]
+
+ pvam_feature = self.pvam(inputs, encoder_word_pos, gsrm_word_pos)
+
+ gsrm_feature, word_out, gsrm_out = self.gsrm(
+ pvam_feature, gsrm_word_pos, gsrm_slf_attn_bias1,
+ gsrm_slf_attn_bias2)
+
+ final_out = self.vsfd(pvam_feature, gsrm_feature)
+ if not self.training:
+ final_out = F.softmax(final_out, axis=1)
+
+ _, decoded_out = paddle.topk(final_out, k=1)
+
+ predicts = OrderedDict([
+ ('predict', final_out),
+ ('pvam_feature', pvam_feature),
+ ('decoded_out', decoded_out),
+ ('word_out', word_out),
+ ('gsrm_out', gsrm_out),
+ ])
+
+ return predicts
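VSFD fuses the visual (PVAM) and semantic (GSRM) features with a learned sigmoid gate; the toy sketch below reproduces just that mixing step, with random stand-ins for the two feature streams:

    import numpy as np

    rng = np.random.default_rng(0)
    pvam = rng.standard_normal((2, 25, 512))    # (batch, max_text_length, c)
    gsrm = rng.standard_normal((2, 25, 512))
    gate = 1.0 / (1.0 + np.exp(-rng.standard_normal((2, 25, 512))))  # sigmoid(fc0(...))
    fused = gate * pvam + (1.0 - gate) * gsrm   # combine_feature in VSFD.forward
    print(fused.shape)                          # (2, 25, 512)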
diff --git a/ppocr/modeling/heads/self_attention.py b/ppocr/modeling/heads/self_attention.py
new file mode 100644
index 00000000..51d5198f
--- /dev/null
+++ b/ppocr/modeling/heads/self_attention.py
@@ -0,0 +1,409 @@
+# copyright (c) 2020 PaddlePaddle Authors. All Rights Reserve.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import math
+
+import paddle
+from paddle import ParamAttr, nn
+from paddle.nn import functional as F
+import paddle.fluid as fluid
+import numpy as np
+gradient_clip = 10
+
+
+class WrapEncoderForFeature(nn.Layer):
+ def __init__(self,
+ src_vocab_size,
+ max_length,
+ n_layer,
+ n_head,
+ d_key,
+ d_value,
+ d_model,
+ d_inner_hid,
+ prepostprocess_dropout,
+ attention_dropout,
+ relu_dropout,
+ preprocess_cmd,
+ postprocess_cmd,
+ weight_sharing,
+ bos_idx=0):
+ super(WrapEncoderForFeature, self).__init__()
+
+ self.prepare_encoder = PrepareEncoder(
+ src_vocab_size,
+ d_model,
+ max_length,
+ prepostprocess_dropout,
+ bos_idx=bos_idx,
+ word_emb_param_name="src_word_emb_table")
+ self.encoder = Encoder(n_layer, n_head, d_key, d_value, d_model,
+ d_inner_hid, prepostprocess_dropout,
+ attention_dropout, relu_dropout, preprocess_cmd,
+ postprocess_cmd)
+
+ def forward(self, enc_inputs):
+ conv_features, src_pos, src_slf_attn_bias = enc_inputs
+ enc_input = self.prepare_encoder(conv_features, src_pos)
+ enc_output = self.encoder(enc_input, src_slf_attn_bias)
+ return enc_output
+
+
+class WrapEncoder(nn.Layer):
+ """
+ embedder + encoder
+ """
+
+ def __init__(self,
+ src_vocab_size,
+ max_length,
+ n_layer,
+ n_head,
+ d_key,
+ d_value,
+ d_model,
+ d_inner_hid,
+ prepostprocess_dropout,
+ attention_dropout,
+ relu_dropout,
+ preprocess_cmd,
+ postprocess_cmd,
+ weight_sharing,
+ bos_idx=0):
+ super(WrapEncoder, self).__init__()
+
+ self.prepare_decoder = PrepareDecoder(
+ src_vocab_size,
+ d_model,
+ max_length,
+ prepostprocess_dropout,
+ bos_idx=bos_idx)
+ self.encoder = Encoder(n_layer, n_head, d_key, d_value, d_model,
+ d_inner_hid, prepostprocess_dropout,
+ attention_dropout, relu_dropout, preprocess_cmd,
+ postprocess_cmd)
+
+ def forward(self, enc_inputs):
+ src_word, src_pos, src_slf_attn_bias = enc_inputs
+ enc_input = self.prepare_decoder(src_word, src_pos)
+ enc_output = self.encoder(enc_input, src_slf_attn_bias)
+ return enc_output
+
+
+class Encoder(nn.Layer):
+ """
+ encoder
+ """
+
+ def __init__(self,
+ n_layer,
+ n_head,
+ d_key,
+ d_value,
+ d_model,
+ d_inner_hid,
+ prepostprocess_dropout,
+ attention_dropout,
+ relu_dropout,
+ preprocess_cmd="n",
+ postprocess_cmd="da"):
+
+ super(Encoder, self).__init__()
+
+ self.encoder_layers = list()
+ for i in range(n_layer):
+ self.encoder_layers.append(
+ self.add_sublayer(
+ "layer_%d" % i,
+ EncoderLayer(n_head, d_key, d_value, d_model, d_inner_hid,
+ prepostprocess_dropout, attention_dropout,
+ relu_dropout, preprocess_cmd,
+ postprocess_cmd)))
+ self.processer = PrePostProcessLayer(preprocess_cmd, d_model,
+ prepostprocess_dropout)
+
+ def forward(self, enc_input, attn_bias):
+ for encoder_layer in self.encoder_layers:
+ enc_output = encoder_layer(enc_input, attn_bias)
+ enc_input = enc_output
+ enc_output = self.processer(enc_output)
+ return enc_output
+
+
+class EncoderLayer(nn.Layer):
+ """
+ EncoderLayer
+ """
+
+ def __init__(self,
+ n_head,
+ d_key,
+ d_value,
+ d_model,
+ d_inner_hid,
+ prepostprocess_dropout,
+ attention_dropout,
+ relu_dropout,
+ preprocess_cmd="n",
+ postprocess_cmd="da"):
+
+ super(EncoderLayer, self).__init__()
+ self.preprocesser1 = PrePostProcessLayer(preprocess_cmd, d_model,
+ prepostprocess_dropout)
+ self.self_attn = MultiHeadAttention(d_key, d_value, d_model, n_head,
+ attention_dropout)
+ self.postprocesser1 = PrePostProcessLayer(postprocess_cmd, d_model,
+ prepostprocess_dropout)
+
+ self.preprocesser2 = PrePostProcessLayer(preprocess_cmd, d_model,
+ prepostprocess_dropout)
+ self.ffn = FFN(d_inner_hid, d_model, relu_dropout)
+ self.postprocesser2 = PrePostProcessLayer(postprocess_cmd, d_model,
+ prepostprocess_dropout)
+
+ def forward(self, enc_input, attn_bias):
+ attn_output = self.self_attn(
+ self.preprocesser1(enc_input), None, None, attn_bias)
+ attn_output = self.postprocesser1(attn_output, enc_input)
+ ffn_output = self.ffn(self.preprocesser2(attn_output))
+ ffn_output = self.postprocesser2(ffn_output, attn_output)
+ return ffn_output
+
+
+class MultiHeadAttention(nn.Layer):
+ """
+ Multi-Head Attention
+ """
+
+ def __init__(self, d_key, d_value, d_model, n_head=1, dropout_rate=0.):
+ super(MultiHeadAttention, self).__init__()
+ self.n_head = n_head
+ self.d_key = d_key
+ self.d_value = d_value
+ self.d_model = d_model
+ self.dropout_rate = dropout_rate
+ self.q_fc = paddle.nn.Linear(
+ in_features=d_model, out_features=d_key * n_head, bias_attr=False)
+ self.k_fc = paddle.nn.Linear(
+ in_features=d_model, out_features=d_key * n_head, bias_attr=False)
+ self.v_fc = paddle.nn.Linear(
+ in_features=d_model, out_features=d_value * n_head, bias_attr=False)
+ self.proj_fc = paddle.nn.Linear(
+ in_features=d_value * n_head, out_features=d_model, bias_attr=False)
+
+ def _prepare_qkv(self, queries, keys, values, cache=None):
+ if keys is None: # self-attention
+ keys, values = queries, queries
+ static_kv = False
+ else: # cross-attention
+ static_kv = True
+
+ q = self.q_fc(queries)
+ q = paddle.reshape(x=q, shape=[0, 0, self.n_head, self.d_key])
+ q = paddle.transpose(x=q, perm=[0, 2, 1, 3])
+
+ if cache is not None and static_kv and "static_k" in cache:
+ # encoder-decoder attention at inference: reuse cached keys/values
+ k = cache["static_k"]
+ v = cache["static_v"]
+ else:
+ k = self.k_fc(keys)
+ v = self.v_fc(values)
+ k = paddle.reshape(x=k, shape=[0, 0, self.n_head, self.d_key])
+ k = paddle.transpose(x=k, perm=[0, 2, 1, 3])
+ v = paddle.reshape(x=v, shape=[0, 0, self.n_head, self.d_value])
+ v = paddle.transpose(x=v, perm=[0, 2, 1, 3])
+
+ if cache is not None:
+ if static_kv and "static_k" not in cache:
+ # encoder-decoder attention at inference: cache not yet populated
+ cache["static_k"], cache["static_v"] = k, v
+ elif not static_kv:
+ # for decoder self-attention in inference
+ cache_k, cache_v = cache["k"], cache["v"]
+ k = paddle.concat([cache_k, k], axis=2)
+ v = paddle.concat([cache_v, v], axis=2)
+ cache["k"], cache["v"] = k, v
+
+ return q, k, v
+
+ def forward(self, queries, keys, values, attn_bias, cache=None):
+ # compute q, k, v
+ keys = queries if keys is None else keys
+ values = keys if values is None else values
+ q, k, v = self._prepare_qkv(queries, keys, values, cache)
+
+ # scaled dot-product attention
+ product = paddle.matmul(x=q, y=k, transpose_y=True)
+ product = product * self.d_model**-0.5
+ if attn_bias is not None:
+ product += attn_bias
+ weights = F.softmax(product)
+ if self.dropout_rate:
+ weights = F.dropout(
+ weights, p=self.dropout_rate, mode="downscale_in_infer")
+ out = paddle.matmul(weights, v)
+
+ # combine heads
+ out = paddle.transpose(out, perm=[0, 2, 1, 3])
+ out = paddle.reshape(x=out, shape=[0, 0, out.shape[2] * out.shape[3]])
+
+ # project to output
+ out = self.proj_fc(out)
+
+ return out
+
+
+class PrePostProcessLayer(nn.Layer):
+ """
+ PrePostProcessLayer
+ """
+
+ def __init__(self, process_cmd, d_model, dropout_rate):
+ super(PrePostProcessLayer, self).__init__()
+ self.process_cmd = process_cmd
+ self.functors = []
+ for cmd in self.process_cmd:
+ if cmd == "a": # add residual connection
+ self.functors.append(lambda x, y: x + y if y is not None else x)
+ elif cmd == "n": # add layer normalization
+ self.functors.append(
+ self.add_sublayer(
+ "layer_norm_%d" % len(
+ self.sublayers(include_sublayers=False)),
+ paddle.nn.LayerNorm(
+ normalized_shape=d_model,
+ weight_attr=fluid.ParamAttr(
+ initializer=fluid.initializer.Constant(1.)),
+ bias_attr=fluid.ParamAttr(
+ initializer=fluid.initializer.Constant(0.)))))
+ elif cmd == "d": # add dropout
+ self.functors.append(lambda x: F.dropout(
+ x, p=dropout_rate, mode="downscale_in_infer")
+ if dropout_rate else x)
+
+ def forward(self, x, residual=None):
+ for i, cmd in enumerate(self.process_cmd):
+ if cmd == "a":
+ x = self.functors[i](x, residual)
+ else:
+ x = self.functors[i](x)
+ return x
+
+
+class PrepareEncoder(nn.Layer):
+ def __init__(self,
+ src_vocab_size,
+ src_emb_dim,
+ src_max_len,
+ dropout_rate=0,
+ bos_idx=0,
+ word_emb_param_name=None,
+ pos_enc_param_name=None):
+ super(PrepareEncoder, self).__init__()
+ self.src_emb_dim = src_emb_dim
+ self.src_max_len = src_max_len
+ self.emb = paddle.nn.Embedding(
+ num_embeddings=self.src_max_len,
+ embedding_dim=self.src_emb_dim,
+ sparse=True)
+ self.dropout_rate = dropout_rate
+
+ def forward(self, src_word, src_pos):
+ src_word_emb = src_word
+ src_word_emb = fluid.layers.cast(src_word_emb, 'float32')
+ src_word_emb = paddle.scale(x=src_word_emb, scale=self.src_emb_dim**0.5)
+ src_pos = paddle.squeeze(src_pos, axis=-1)
+ src_pos_enc = self.emb(src_pos)
+ src_pos_enc.stop_gradient = True
+ enc_input = src_word_emb + src_pos_enc
+ if self.dropout_rate:
+ out = F.dropout(
+ x=enc_input, p=self.dropout_rate, mode="downscale_in_infer")
+ else:
+ out = enc_input
+ return out
+
+
+class PrepareDecoder(nn.Layer):
+ def __init__(self,
+ src_vocab_size,
+ src_emb_dim,
+ src_max_len,
+ dropout_rate=0,
+ bos_idx=0,
+ word_emb_param_name=None,
+ pos_enc_param_name=None):
+ super(PrepareDecoder, self).__init__()
+ self.src_emb_dim = src_emb_dim
+ """
+ self.emb0 = Embedding(num_embeddings=src_vocab_size,
+ embedding_dim=src_emb_dim)
+ """
+ self.emb0 = paddle.nn.Embedding(
+ num_embeddings=src_vocab_size,
+ embedding_dim=self.src_emb_dim,
+ padding_idx=bos_idx,
+ weight_attr=paddle.ParamAttr(
+ name=word_emb_param_name,
+ initializer=nn.initializer.Normal(0., src_emb_dim**-0.5)))
+ self.emb1 = paddle.nn.Embedding(
+ num_embeddings=src_max_len,
+ embedding_dim=self.src_emb_dim,
+ weight_attr=paddle.ParamAttr(name=pos_enc_param_name))
+ self.dropout_rate = dropout_rate
+
+ def forward(self, src_word, src_pos):
+ src_word = fluid.layers.cast(src_word, 'int64')
+ src_word = paddle.squeeze(src_word, axis=-1)
+ src_word_emb = self.emb0(src_word)
+ src_word_emb = paddle.scale(x=src_word_emb, scale=self.src_emb_dim**0.5)
+ src_pos = paddle.squeeze(src_pos, axis=-1)
+ src_pos_enc = self.emb1(src_pos)
+ src_pos_enc.stop_gradient = True
+ enc_input = src_word_emb + src_pos_enc
+ if self.dropout_rate:
+ out = F.dropout(
+ x=enc_input, p=self.dropout_rate, mode="downscale_in_infer")
+ else:
+ out = enc_input
+ return out
+
+
+class FFN(nn.Layer):
+ """
+ Feed-Forward Network
+ """
+
+ def __init__(self, d_inner_hid, d_model, dropout_rate):
+ super(FFN, self).__init__()
+ self.dropout_rate = dropout_rate
+ self.fc1 = paddle.nn.Linear(
+ in_features=d_model, out_features=d_inner_hid)
+ self.fc2 = paddle.nn.Linear(
+ in_features=d_inner_hid, out_features=d_model)
+
+ def forward(self, x):
+ hidden = self.fc1(x)
+ hidden = F.relu(hidden)
+ if self.dropout_rate:
+ hidden = F.dropout(
+ hidden, p=self.dropout_rate, mode="downscale_in_infer")
+ out = self.fc2(hidden)
+ return out
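The attention core above is standard scaled dot-product attention with an additive bias; positions carrying a -1e9 bias get effectively zero weight after the softmax. A NumPy restatement (note the scaling uses d_model**-0.5, matching MultiHeadAttention.forward):

    import numpy as np

    def attention(q, k, v, bias=None, d_model=512):
        scores = q @ k.transpose(0, 1, 3, 2) * d_model ** -0.5
        if bias is not None:
            scores = scores + bias           # -1e9 entries vanish after softmax
        w = np.exp(scores - scores.max(-1, keepdims=True))
        w = w / w.sum(-1, keepdims=True)
        return w @ v

    q = k = v = np.ones((1, 8, 25, 64))       # (batch, heads, length, d_key)
    bias = np.triu(np.ones((25, 25)), 1) * -1e9
    print(attention(q, k, v, bias).shape)     # (1, 8, 25, 64)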
diff --git a/ppocr/postprocess/__init__.py b/ppocr/postprocess/__init__.py
index c9b42e08..0156e438 100644
--- a/ppocr/postprocess/__init__.py
+++ b/ppocr/postprocess/__init__.py
@@ -26,11 +26,12 @@ def build_post_process(config, global_config=None):
from .db_postprocess import DBPostProcess
from .east_postprocess import EASTPostProcess
from .sast_postprocess import SASTPostProcess
- from .rec_postprocess import CTCLabelDecode, AttnLabelDecode
+ from .rec_postprocess import CTCLabelDecode, AttnLabelDecode, SRNLabelDecode
from .cls_postprocess import ClsPostProcess
support_dict = [
- 'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode', 'AttnLabelDecode', 'ClsPostProcess'
+ 'DBPostProcess', 'EASTPostProcess', 'SASTPostProcess', 'CTCLabelDecode',
+ 'AttnLabelDecode', 'ClsPostProcess', 'SRNLabelDecode'
]
config = copy.deepcopy(config)
diff --git a/ppocr/postprocess/rec_postprocess.py b/ppocr/postprocess/rec_postprocess.py
index 4d078994..76a700e1 100644
--- a/ppocr/postprocess/rec_postprocess.py
+++ b/ppocr/postprocess/rec_postprocess.py
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
+import string
import paddle
from paddle.nn import functional as F
@@ -24,19 +25,28 @@ class BaseRecLabelDecode(object):
character_type='ch',
use_space_char=False):
support_character_type = [
- 'ch', 'en', 'en_sensitive', 'french', 'german', 'japan', 'korean', 'it',
- 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs', 'oc', 'rsc', 'bg',
- 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi', 'mr', 'ne'
+ 'ch', 'en', 'EN_symbol', 'french', 'german', 'japan', 'korean',
+ 'it', 'xi', 'pu', 'ru', 'ar', 'ta', 'ug', 'fa', 'ur', 'rs', 'oc',
+ 'rsc', 'bg', 'uk', 'be', 'te', 'ka', 'chinese_cht', 'hi', 'mr',
+ 'ne', 'EN'
]
assert character_type in support_character_type, "Only {} are supported now but get {}".format(
support_character_type, character_type)
+ self.beg_str = "sos"
+ self.end_str = "eos"
+
if character_type == "en":
self.character_str = "0123456789abcdefghijklmnopqrstuvwxyz"
dict_character = list(self.character_str)
- elif character_type in ["ch", "french", "german", "japan", "korean"]:
+ elif character_type == "EN_symbol":
+ # same with ASTER setting (use 94 char).
+ self.character_str = string.printable[:-6]
+ dict_character = list(self.character_str)
+ elif character_type in support_character_type:
self.character_str = ""
- assert character_dict_path is not None, "character_dict_path should not be None when character_type is ch"
+ assert character_dict_path is not None, "character_dict_path should not be None when character_type is {}".format(
+ character_type)
with open(character_dict_path, "rb") as fin:
lines = fin.readlines()
for line in lines:
@@ -45,11 +55,7 @@ class BaseRecLabelDecode(object):
if use_space_char:
self.character_str += " "
dict_character = list(self.character_str)
- elif character_type == "en_sensitive":
- # same with ASTER setting (use 94 char).
- import string
- self.character_str = string.printable[:-6]
- dict_character = list(self.character_str)
+
else:
raise NotImplementedError
self.character_type = character_type
@@ -106,7 +112,6 @@ class CTCLabelDecode(BaseRecLabelDecode):
def __call__(self, preds, label=None, *args, **kwargs):
if isinstance(preds, paddle.Tensor):
preds = preds.numpy()
-
preds_idx = preds.argmax(axis=2)
preds_prob = preds.max(axis=2)
text = self.decode(preds_idx, preds_prob, is_remove_duplicate=True)
@@ -155,3 +160,84 @@ class AttnLabelDecode(BaseRecLabelDecode):
assert False, "unsupport type %s in get_beg_end_flag_idx" \
% beg_or_end
return idx
+
+
+class SRNLabelDecode(BaseRecLabelDecode):
+ """ Convert between text-label and text-index """
+
+ def __init__(self,
+ character_dict_path=None,
+ character_type='en',
+ use_space_char=False,
+ **kwargs):
+ super(SRNLabelDecode, self).__init__(character_dict_path,
+ character_type, use_space_char)
+
+ def __call__(self, preds, label=None, *args, **kwargs):
+ pred = preds['predict']
+ char_num = len(self.character_str) + 2
+ if isinstance(pred, paddle.Tensor):
+ pred = pred.numpy()
+ pred = np.reshape(pred, [-1, char_num])
+
+ preds_idx = np.argmax(pred, axis=1)
+ preds_prob = np.max(pred, axis=1)
+
+ preds_idx = np.reshape(preds_idx, [-1, 25])
+
+ preds_prob = np.reshape(preds_prob, [-1, 25])
+
+ text = self.decode(preds_idx, preds_prob)
+
+ if label is None:
+ return text
+ label = self.decode(label)
+ return text, label
+
+ def decode(self, text_index, text_prob=None, is_remove_duplicate=False):
+ """ convert text-index into text-label. """
+ result_list = []
+ ignored_tokens = self.get_ignored_tokens()
+ batch_size = len(text_index)
+
+ for batch_idx in range(batch_size):
+ char_list = []
+ conf_list = []
+ for idx in range(len(text_index[batch_idx])):
+ if text_index[batch_idx][idx] in ignored_tokens:
+ continue
+ if is_remove_duplicate:
+ # only for predict
+ if idx > 0 and text_index[batch_idx][idx - 1] == text_index[
+ batch_idx][idx]:
+ continue
+ char_list.append(self.character[int(text_index[batch_idx][
+ idx])])
+ if text_prob is not None:
+ conf_list.append(text_prob[batch_idx][idx])
+ else:
+ conf_list.append(1)
+
+ text = ''.join(char_list)
+ result_list.append((text, np.mean(conf_list)))
+ return result_list
+
+ def add_special_char(self, dict_character):
+ dict_character = dict_character + [self.beg_str, self.end_str]
+ return dict_character
+
+ def get_ignored_tokens(self):
+ beg_idx = self.get_beg_end_flag_idx("beg")
+ end_idx = self.get_beg_end_flag_idx("end")
+ return [beg_idx, end_idx]
+
+ def get_beg_end_flag_idx(self, beg_or_end):
+ if beg_or_end == "beg":
+ idx = np.array(self.dict[self.beg_str])
+ elif beg_or_end == "end":
+ idx = np.array(self.dict[self.end_str])
+ else:
+ assert False, "unsupport type %s in get_beg_end_flag_idx" \
+ % beg_or_end
+ return idx
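A toy walk-through of `SRNLabelDecode.decode` with a hypothetical four-character vocabulary (indices 4 and 5 are the appended sos/eos markers returned by `get_ignored_tokens`):

    import numpy as np

    character = ['a', 'b', 'c', 'd', 'sos', 'eos']
    ignored = [4, 5]                         # beg/end flag indices
    preds_idx = np.array([[0, 1, 1, 5, 5]])  # "abb" followed by eos padding
    text = ''.join(character[i] for i in preds_idx[0] if i not in ignored)
    print(text)  # abb  (duplicates kept: SRN decodes fixed slots, not CTC paths)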
diff --git a/tools/export_model.py b/tools/export_model.py
index 74357d58..1e9526e0 100755
--- a/tools/export_model.py
+++ b/tools/export_model.py
@@ -31,6 +31,14 @@ from ppocr.utils.logging import get_logger
from tools.program import load_config, merge_config, ArgsParser
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("-c", "--config", help="configuration file to use")
+ parser.add_argument(
+ "-o", "--output_path", type=str, default='./output/infer/')
+ return parser.parse_args()
+
+
def main():
FLAGS = ArgsParser().parse_args()
config = load_config(FLAGS.config)
@@ -51,14 +59,40 @@ def main():
model.eval()
save_path = '{}/inference'.format(config['Global']['save_inference_dir'])
- infer_shape = [3, 32, 100] if config['Architecture'][
- 'model_type'] != "det" else [3, 640, 640]
- model = to_static(
- model,
- input_spec=[
+
+ if config['Architecture']['algorithm'] == "SRN":
+ other_shape = [
paddle.static.InputSpec(
- shape=[None] + infer_shape, dtype='float32')
- ])
+ shape=[None, 1, 64, 256], dtype='float32'), [
+ paddle.static.InputSpec(
+ shape=[None, 256, 1],
+ dtype="int64"), paddle.static.InputSpec(
+ shape=[None, 25, 1],
+ dtype="int64"), paddle.static.InputSpec(
+ shape=[None, 8, 25, 25], dtype="int64"),
+ paddle.static.InputSpec(
+ shape=[None, 8, 25, 25], dtype="int64")
+ ]
+ ]
+ model = to_static(model, input_spec=other_shape)
+ else:
+ infer_shape = [3, -1, -1]
+ if config['Architecture']['model_type'] == "rec":
+ infer_shape = [3, 32, -1] # for rec model, H must be 32
+ if 'Transform' in config['Architecture'] and config['Architecture'][
+ 'Transform'] is not None and config['Architecture'][
+ 'Transform']['name'] == 'TPS':
+ logger.info(
+ 'TPS in the network does not support variable-length input; the input size must match the training size'
+ )
+ infer_shape[-1] = 100
+ model = to_static(
+ model,
+ input_spec=[
+ paddle.static.InputSpec(
+ shape=[None] + infer_shape, dtype='float32')
+ ])
+
paddle.jit.save(model, save_path)
logger.info('inference model is saved to {}'.format(save_path))
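SRN cannot be exported with a variable-width spec because its auxiliary inputs encode fixed position counts. For the default config, the five static shapes pinned above work out to:

    # static export shapes for SRN (default 1x64x256 input, 8 heads, length 25)
    srn_input_shapes = {
        'image':               [None, 1, 64, 256],
        'encoder_word_pos':    [None, 256, 1],     # (imgH/8) * (imgW/8) positions
        'gsrm_word_pos':       [None, 25, 1],      # max_text_length positions
        'gsrm_slf_attn_bias1': [None, 8, 25, 25],  # per-head forward mask
        'gsrm_slf_attn_bias2': [None, 8, 25, 25],  # per-head backward mask
    }
    print(srn_input_shapes['image'])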
diff --git a/tools/infer/predict_rec.py b/tools/infer/predict_rec.py
index 974fdbb6..fd895e50 100755
--- a/tools/infer/predict_rec.py
+++ b/tools/infer/predict_rec.py
@@ -25,6 +25,7 @@ import numpy as np
import math
import time
import traceback
+import paddle
import tools.infer.utility as utility
from ppocr.postprocess import build_post_process
@@ -46,6 +47,13 @@ class TextRecognizer(object):
"character_dict_path": args.rec_char_dict_path,
"use_space_char": args.use_space_char
}
+ if self.rec_algorithm == "SRN":
+ postprocess_params = {
+ 'name': 'SRNLabelDecode',
+ "character_type": args.rec_char_type,
+ "character_dict_path": args.rec_char_dict_path,
+ "use_space_char": args.use_space_char
+ }
self.postprocess_op = build_post_process(postprocess_params)
self.predictor, self.input_tensor, self.output_tensors = \
utility.create_predictor(args, 'rec', logger)
@@ -70,6 +78,78 @@ class TextRecognizer(object):
padding_im[:, :, 0:resized_w] = resized_image
return padding_im
+ def resize_norm_img_srn(self, img, image_shape):
+ imgC, imgH, imgW = image_shape
+
+ img_black = np.zeros((imgH, imgW))
+ im_hei = img.shape[0]
+ im_wid = img.shape[1]
+
+ if im_wid <= im_hei * 1:
+ img_new = cv2.resize(img, (imgH * 1, imgH))
+ elif im_wid <= im_hei * 2:
+ img_new = cv2.resize(img, (imgH * 2, imgH))
+ elif im_wid <= im_hei * 3:
+ img_new = cv2.resize(img, (imgH * 3, imgH))
+ else:
+ img_new = cv2.resize(img, (imgW, imgH))
+
+ img_np = np.asarray(img_new)
+ img_np = cv2.cvtColor(img_np, cv2.COLOR_BGR2GRAY)
+ img_black[:, 0:img_np.shape[1]] = img_np
+ img_black = img_black[:, :, np.newaxis]
+
+ row, col, c = img_black.shape
+ c = 1
+
+ return np.reshape(img_black, (c, row, col)).astype(np.float32)
+
+ def srn_other_inputs(self, image_shape, num_heads, max_text_length):
+
+ imgC, imgH, imgW = image_shape
+ feature_dim = int((imgH / 8) * (imgW / 8))
+
+ encoder_word_pos = np.array(range(0, feature_dim)).reshape(
+ (feature_dim, 1)).astype('int64')
+ gsrm_word_pos = np.array(range(0, max_text_length)).reshape(
+ (max_text_length, 1)).astype('int64')
+
+ gsrm_attn_bias_data = np.ones((1, max_text_length, max_text_length))
+ gsrm_slf_attn_bias1 = np.triu(gsrm_attn_bias_data, 1).reshape(
+ [-1, 1, max_text_length, max_text_length])
+ gsrm_slf_attn_bias1 = np.tile(
+ gsrm_slf_attn_bias1,
+ [1, num_heads, 1, 1]).astype('float32') * [-1e9]
+
+ gsrm_slf_attn_bias2 = np.tril(gsrm_attn_bias_data, -1).reshape(
+ [-1, 1, max_text_length, max_text_length])
+ gsrm_slf_attn_bias2 = np.tile(
+ gsrm_slf_attn_bias2,
+ [1, num_heads, 1, 1]).astype('float32') * [-1e9]
+
+ encoder_word_pos = encoder_word_pos[np.newaxis, :]
+ gsrm_word_pos = gsrm_word_pos[np.newaxis, :]
+
+ return [
+ encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
+ gsrm_slf_attn_bias2
+ ]
+
+ def process_image_srn(self, img, image_shape, num_heads, max_text_length):
+ norm_img = self.resize_norm_img_srn(img, image_shape)
+ norm_img = norm_img[np.newaxis, :]
+
+ [encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1, gsrm_slf_attn_bias2] = \
+ self.srn_other_inputs(image_shape, num_heads, max_text_length)
+
+ gsrm_slf_attn_bias1 = gsrm_slf_attn_bias1.astype(np.float32)
+ gsrm_slf_attn_bias2 = gsrm_slf_attn_bias2.astype(np.float32)
+ encoder_word_pos = encoder_word_pos.astype(np.int64)
+ gsrm_word_pos = gsrm_word_pos.astype(np.int64)
+
+ return (norm_img, encoder_word_pos, gsrm_word_pos, gsrm_slf_attn_bias1,
+ gsrm_slf_attn_bias2)
+
def __call__(self, img_list):
img_num = len(img_list)
# Calculate the aspect ratio of all text bars
@@ -93,21 +173,64 @@ class TextRecognizer(object):
wh_ratio = w * 1.0 / h
max_wh_ratio = max(max_wh_ratio, wh_ratio)
for ino in range(beg_img_no, end_img_no):
- # norm_img = self.resize_norm_img(img_list[ino], max_wh_ratio)
- norm_img = self.resize_norm_img(img_list[indices[ino]],
- max_wh_ratio)
- norm_img = norm_img[np.newaxis, :]
- norm_img_batch.append(norm_img)
+ if self.rec_algorithm != "SRN":
+ norm_img = self.resize_norm_img(img_list[indices[ino]],
+ max_wh_ratio)
+ norm_img = norm_img[np.newaxis, :]
+ norm_img_batch.append(norm_img)
+ else:
+ norm_img = self.process_image_srn(
+ img_list[indices[ino]], self.rec_image_shape, 8, 25)
+ encoder_word_pos_list = []
+ gsrm_word_pos_list = []
+ gsrm_slf_attn_bias1_list = []
+ gsrm_slf_attn_bias2_list = []
+ encoder_word_pos_list.append(norm_img[1])
+ gsrm_word_pos_list.append(norm_img[2])
+ gsrm_slf_attn_bias1_list.append(norm_img[3])
+ gsrm_slf_attn_bias2_list.append(norm_img[4])
+ norm_img_batch.append(norm_img[0])
norm_img_batch = np.concatenate(norm_img_batch)
norm_img_batch = norm_img_batch.copy()
- starttime = time.time()
- self.input_tensor.copy_from_cpu(norm_img_batch)
- self.predictor.run()
- outputs = []
- for output_tensor in self.output_tensors:
- output = output_tensor.copy_to_cpu()
- outputs.append(output)
- preds = outputs[0]
+
+ if self.rec_algorithm == "SRN":
+ starttime = time.time()
+ encoder_word_pos_list = np.concatenate(encoder_word_pos_list)
+ gsrm_word_pos_list = np.concatenate(gsrm_word_pos_list)
+ gsrm_slf_attn_bias1_list = np.concatenate(
+ gsrm_slf_attn_bias1_list)
+ gsrm_slf_attn_bias2_list = np.concatenate(
+ gsrm_slf_attn_bias2_list)
+
+ inputs = [
+ norm_img_batch,
+ encoder_word_pos_list,
+ gsrm_word_pos_list,
+ gsrm_slf_attn_bias1_list,
+ gsrm_slf_attn_bias2_list,
+ ]
+ input_names = self.predictor.get_input_names()
+ for i in range(len(input_names)):
+ input_tensor = self.predictor.get_input_handle(input_names[
+ i])
+ input_tensor.copy_from_cpu(inputs[i])
+ self.predictor.run()
+ outputs = []
+ for output_tensor in self.output_tensors:
+ output = output_tensor.copy_to_cpu()
+ outputs.append(output)
+ preds = {"predict": outputs[2]}
+ else:
+ starttime = time.time()
+ self.input_tensor.copy_from_cpu(norm_img_batch)
+ self.predictor.run()
+
+ outputs = []
+ for output_tensor in self.output_tensors:
+ output = output_tensor.copy_to_cpu()
+ outputs.append(output)
+ preds = outputs[0]
+
rec_result = self.postprocess_op(preds)
for rno in range(len(rec_result)):
rec_res[indices[beg_img_no + rno]] = rec_result[rno]
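The SRN branch batches the per-crop auxiliary tensors by concatenating along the leading axis before feeding each named input handle. A minimal shape sketch of that assembly (values mirror the defaults used above):

    import numpy as np

    encoder_word_pos = np.arange(256, dtype='int64').reshape(1, 256, 1)
    batch = np.concatenate([encoder_word_pos] * 4)  # a batch of 4 crops
    print(batch.shape)  # (4, 256, 1): one identical position tensor per crop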
diff --git a/tools/infer/utility.py b/tools/infer/utility.py
index 966fa3cc..4171a29b 100755
--- a/tools/infer/utility.py
+++ b/tools/infer/utility.py
@@ -70,7 +70,7 @@ def parse_args():
default="./ppocr/utils/ppocr_keys_v1.txt")
parser.add_argument("--use_space_char", type=str2bool, default=True)
parser.add_argument(
- "--vis_font_path", type=str, default="./doc/simfang.ttf")
+ "--vis_font_path", type=str, default="./doc/fonts/simfang.ttf")
parser.add_argument("--drop_score", type=float, default=0.5)
# params for text classifier
diff --git a/tools/infer_rec.py b/tools/infer_rec.py
index 7e4b0811..075ec261 100755
--- a/tools/infer_rec.py
+++ b/tools/infer_rec.py
@@ -62,7 +62,13 @@ def main():
elif op_name in ['RecResizeImg']:
op[op_name]['infer_mode'] = True
elif op_name == 'KeepKeys':
- op[op_name]['keep_keys'] = ['image']
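+            # SRN needs the positional and attention-bias tensors produced
+            # by the transforms, so keep them alongside the image.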
+ if config['Architecture']['algorithm'] == "SRN":
+ op[op_name]['keep_keys'] = [
+ 'image', 'encoder_word_pos', 'gsrm_word_pos',
+ 'gsrm_slf_attn_bias1', 'gsrm_slf_attn_bias2'
+ ]
+ else:
+ op[op_name]['keep_keys'] = ['image']
transforms.append(op)
global_config['infer_mode'] = True
ops = create_operators(transforms, global_config)
@@ -74,10 +80,25 @@ def main():
img = f.read()
data = {'image': img}
batch = transform(data, ops)
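+        # The SRN transforms also emit the word-position and attention-bias
+        # tensors kept above; add a batch axis and wrap them as tensors.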
+ if config['Architecture']['algorithm'] == "SRN":
+ encoder_word_pos_list = np.expand_dims(batch[1], axis=0)
+ gsrm_word_pos_list = np.expand_dims(batch[2], axis=0)
+ gsrm_slf_attn_bias1_list = np.expand_dims(batch[3], axis=0)
+ gsrm_slf_attn_bias2_list = np.expand_dims(batch[4], axis=0)
+
+ others = [
+ paddle.to_tensor(encoder_word_pos_list),
+ paddle.to_tensor(gsrm_word_pos_list),
+ paddle.to_tensor(gsrm_slf_attn_bias1_list),
+ paddle.to_tensor(gsrm_slf_attn_bias2_list)
+ ]
images = np.expand_dims(batch[0], axis=0)
images = paddle.to_tensor(images)
- preds = model(images)
+ if config['Architecture']['algorithm'] == "SRN":
+ preds = model(images, others)
+ else:
+ preds = model(images)
post_result = post_process_class(preds)
-        for rec_reuslt in post_result:
-            logger.info('\t result: {}'.format(rec_reuslt))
+        for rec_result in post_result:
+            logger.info('\t result: {}'.format(rec_result))
diff --git a/tools/program.py b/tools/program.py
index c2915426..694d6415 100755
--- a/tools/program.py
+++ b/tools/program.py
@@ -174,6 +174,7 @@ def train(config,
best_model_dict = {main_indicator: 0}
best_model_dict.update(pre_best_model_dict)
train_stats = TrainingStats(log_smooth_window, ['lr'])
+ model_average = False
model.train()
if 'start_epoch' in best_model_dict:
@@ -182,8 +183,8 @@ def train(config,
start_epoch = 1
for epoch in range(start_epoch, epoch_num + 1):
- if epoch > 0:
- train_dataloader = build_dataloader(config, 'Train', device, logger)
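+    # Rebuild the dataloader every epoch, seeded with the epoch number so
+    # that shuffling differs between epochs.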
+ train_dataloader = build_dataloader(
+ config, 'Train', device, logger, seed=epoch)
train_batch_cost = 0.0
train_reader_cost = 0.0
batch_sum = 0
@@ -194,7 +195,12 @@ def train(config,
break
lr = optimizer.get_lr()
images = batch[0]
- preds = model(images)
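+            # SRN consumes the last four batch entries (word positions and
+            # attention biases) as extra inputs, and its recipe applies
+            # model averaging before evaluation, so flag that here.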
+ if config['Architecture']['algorithm'] == "SRN":
+ others = batch[-4:]
+ preds = model(images, others)
+ model_average = True
+ else:
+ preds = model(images)
loss = loss_class(preds, batch)
avg_loss = loss['loss']
avg_loss.backward()
@@ -212,7 +218,7 @@ def train(config,
stats['lr'] = lr
train_stats.update(stats)
- if cal_metric_during_train: # onlt rec and cls need
+ if cal_metric_during_train: # only rec and cls need
batch = [item.numpy() for item in batch]
post_result = post_process_class(preds, batch[1])
eval_class(post_result, batch)
@@ -238,21 +244,28 @@ def train(config,
# eval
if global_step > start_eval_step and \
(global_step - start_eval_step) % eval_batch_step == 0 and dist.get_rank() == 0:
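+            # Average recent parameter snapshots before evaluating; this is
+            # only enabled when SRN training set model_average above.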
+ if model_average:
+ Model_Average = paddle.incubate.optimizer.ModelAverage(
+ 0.15,
+ parameters=model.parameters(),
+ min_average_window=10000,
+ max_average_window=15625)
+ Model_Average.apply()
-            cur_metirc = eval(model, valid_dataloader, post_process_class,
-                              eval_class)
+            cur_metric = eval(model, valid_dataloader, post_process_class,
+                              eval_class)
- cur_metirc_str = 'cur metirc, {}'.format(', '.join(
- ['{}: {}'.format(k, v) for k, v in cur_metirc.items()]))
- logger.info(cur_metirc_str)
+ cur_metric_str = 'cur metric, {}'.format(', '.join(
+ ['{}: {}'.format(k, v) for k, v in cur_metric.items()]))
+ logger.info(cur_metric_str)
# logger metric
if vdl_writer is not None:
- for k, v in cur_metirc.items():
+ for k, v in cur_metric.items():
if isinstance(v, (float, int)):
vdl_writer.add_scalar('EVAL/{}'.format(k),
- cur_metirc[k], global_step)
- if cur_metirc[main_indicator] >= best_model_dict[
+ cur_metric[k], global_step)
+ if cur_metric[main_indicator] >= best_model_dict[
main_indicator]:
- best_model_dict.update(cur_metirc)
+ best_model_dict.update(cur_metric)
best_model_dict['best_epoch'] = epoch
save_model(
model,
@@ -263,7 +276,7 @@ def train(config,
prefix='best_accuracy',
best_model_dict=best_model_dict,
epoch=epoch)
- best_str = 'best metirc, {}'.format(', '.join([
+ best_str = 'best metric, {}'.format(', '.join([
'{}: {}'.format(k, v) for k, v in best_model_dict.items()
]))
logger.info(best_str)
@@ -273,6 +286,7 @@ def train(config,
best_model_dict[main_indicator],
global_step)
global_step += 1
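+            # Ensure gradients are cleared before the next iteration starts.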
+ optimizer.clear_grad()
batch_start = time.time()
if dist.get_rank() == 0:
save_model(
@@ -294,7 +308,7 @@ def train(config,
prefix='iter_epoch_{}'.format(epoch),
best_model_dict=best_model_dict,
epoch=epoch)
- best_str = 'best metirc, {}'.format(', '.join(
+ best_str = 'best metric, {}'.format(', '.join(
['{}: {}'.format(k, v) for k, v in best_model_dict.items()]))
logger.info(best_str)
if dist.get_rank() == 0 and vdl_writer is not None:
@@ -312,8 +326,9 @@ def eval(model, valid_dataloader, post_process_class, eval_class):
if idx >= len(valid_dataloader):
break
images = batch[0]
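+        # batch[-4:] holds the SRN-specific extra inputs appended by the
+        # dataloader; this assumes the model's forward accepts (and, for
+        # non-SRN heads, ignores) the optional second argument.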
+ others = batch[-4:]
start = time.time()
- preds = model(images)
+ preds = model(images, others)
batch = [item.numpy() for item in batch]
# Obtain usable results from post-processing methods
@@ -323,13 +338,13 @@ def eval(model, valid_dataloader, post_process_class, eval_class):
eval_class(post_result, batch)
pbar.update(1)
total_frame += len(images)
- # Get final metirc,eg. acc or hmean
- metirc = eval_class.get_metric()
+    # Get the final metric, e.g. acc or hmean
+ metric = eval_class.get_metric()
pbar.close()
model.train()
- metirc['fps'] = total_frame / total_time
- return metirc
+ metric['fps'] = total_frame / total_time
+ return metric
def preprocess(is_train=False):
diff --git a/train.sh b/train.sh
index c511c516..8fe861a3 100644
--- a/train.sh
+++ b/train.sh
@@ -1,5 +1,2 @@
-# for paddle.__version__ >= 2.0rc1
+# recommended paddle.__version__ == 2.0.0
python3 -m paddle.distributed.launch --gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/rec/rec_mv3_none_bilstm_ctc.yml
-
-# for paddle.__version__ < 2.0rc1
-# python3 -m paddle.distributed.launch --selected_gpus '0,1,2,3,4,5,6,7' tools/train.py -c configs/rec/rec_mv3_none_bilstm_ctc.yml