forked from mindspore-Ecosystem/mindspore

add crnn

This commit is contained in:
parent c70735c19f
commit db674ee7ab

@ -17,7 +17,7 @@ Neural Networks Cells.
 
 Pre-defined building blocks or computing units to construct Neural Networks.
 """
-from . import layer, loss, optim, metrics, wrap, probability, sparse
+from . import layer, loss, optim, metrics, wrap, probability, sparse, dynamic_lr
 from .learning_rate_schedule import *
 from .dynamic_lr import *
 from .cell import Cell, GraphKernel
@ -0,0 +1,244 @@
# Contents

- [Contents](#contents)
- [CRNN Description](#crnn-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
    - [Dataset Prepare](#dataset-prepare)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
        - [Training Script Parameters](#training-script-parameters)
        - [Parameters Configuration](#parameters-configuration)
    - [Dataset Preparation](#dataset-preparation)
- [Training Process](#training-process)
    - [Training](#training)
    - [Distributed Training](#distributed-training)
- [Evaluation Process](#evaluation-process)
    - [Evaluation](#evaluation)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Training Performance](#training-performance)
        - [Evaluation Performance](#evaluation-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)

## [CRNN Description](#contents)

CRNN is a neural network for image-based sequence recognition, applied here to scene text recognition, one of the most important and challenging tasks in image-based sequence recognition. The paper proposes a novel architecture that integrates feature extraction, sequence modeling and transcription into a unified framework. Compared with previous systems for scene text recognition, the proposed architecture has four distinctive properties: (1) it is end-to-end trainable, in contrast to most existing algorithms whose components are trained and tuned separately; (2) it naturally handles sequences of arbitrary length, requiring no character segmentation or horizontal scale normalization; (3) it is not confined to any predefined lexicon and achieves remarkable performance in both lexicon-free and lexicon-based scene text recognition tasks; (4) it produces an effective yet much smaller model, which is more practical for real-world application scenarios.

[Paper](https://arxiv.org/abs/1507.05717): Baoguang Shi, Xiang Bai, Cong Yao, "An End-to-End Trainable Neural Network for Image-based Sequence Recognition and Its Application to Scene Text Recognition", ArXiv, vol. abs/1507.05717, 2015.

## [Model Architecture](#contents)

CRNN uses a VGG16-style backbone for feature extraction, followed by a two-layer bidirectional LSTM, and finally a CTC loss for transcription. See `src/crnn.py` for details.
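As a rough sanity check (not part of the repository), the following sketch traces how the default 32x100 input collapses into the LSTM input sequence, using the pooling and convolution layout defined in `src/crnn.py` and the defaults in `src/config.py`:

```python
# Shape walk-through for the default 32x100 input (derived from src/crnn.py).
h, w = 32, 100         # image_height, image_width from src/config.py
h, w = h // 2, w // 2  # 2x2 max-pool after conv1
h, w = h // 2, w // 2  # 2x2 max-pool after conv2
h = h // 2             # (2, 1) max-pool after conv4 halves height only
h = h // 2             # (2, 1) max-pool after conv6 halves height only
h, w = h - 1, w - 1    # final 2x2 'valid' conv (conv7)
print(h, w)            # -> 1 24: a (batch, 512, 1, 24) feature map, i.e.
                       # num_step=24 time steps of input_size=512 features per image
```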
## [Dataset](#contents)

Note that you can run the scripts based on the dataset mentioned in the original paper or widely used in the relevant domain/network architecture. In the following sections, we will introduce how to run the scripts using the related datasets below.

We use five datasets mentioned in the paper. For training, we use the synthetic dataset ([MJSynth](https://www.robots.ox.ac.uk/~vgg/data/text/) and [SynthText](https://github.com/ankush-me/SynthText)) released by Jaderberg et al. as the training data, which contains 8 million training images and their corresponding ground truth words. For evaluation, we use four popular benchmarks for scene text recognition, namely ICDAR 2003 ([IC03](http://www.iapr-tc11.org/mediawiki/index.php?title=ICDAR_2003_Robust_Reading_Competitions)), ICDAR 2013 ([IC13](https://rrc.cvc.uab.es/?ch=2&com=downloads)), IIIT 5K-word ([IIIT5K](https://cvit.iiit.ac.in/research/projects/cvit-projects/the-iiit-5k-word-dataset)), and Street View Text ([SVT](http://vision.ucsd.edu/~kai/grocr/)).

### [Dataset Prepare](#contents)

For the datasets `IC03`, `IIIT5K` and `SVT`, the original datasets from the official websites cannot be used directly in CRNN:

- `IC03`: the text regions need to be cropped from the original images according to `words.xml`.
- `IIIT5K`: the annotations need to be extracted from the MATLAB data file.
- `SVT`: the text regions need to be cropped from the original images according to `train.xml` or `test.xml`.

We provide `convert_ic03.py`, `convert_iiit5k.py` and `convert_svt.py` as examples of the preprocessing above, which you can refer to.
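For example, the converters could be invoked as follows; all paths are placeholders for wherever you extracted the original datasets, and each script's `-h` output lists the authoritative flags:

```shell
# IC03: crop words according to words.xml
python convert_ic03.py --dataset_dir ./IC03/ --xml_file ./IC03/words.xml --output_dir ./processed --output_annotation ./annotation.txt

# IIIT5K: extract labels from the MATLAB annotation file
python convert_iiit5k.py --mat_file ./IIIT5K/testdata.mat --output_annotation ./annotation.txt

# SVT: crop words according to train.xml / test.xml
python convert_svt.py --dataset_dir ./svt1/ --xml_file ./svt1/test.xml --output_dir ./processed
```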
## [Environment Requirements](#contents)

- Hardware (Ascend)
    - Prepare hardware environment with Ascend processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. You will be able to have access to related resources once approved.
- Framework
    - [MindSpore](https://gitee.com/mindspore/mindspore)
- For more information, please check the resources below:
    - [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
    - [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)

## [Quick Start](#contents)

- After the dataset is prepared, you may start running the training or the evaluation scripts as follows:

- Running on Ascend

```shell
# distributed training example on Ascend
$ bash run_distribute_train.sh [DATASET_NAME] [RANK_TABLE_FILE] [DATASET_PATH]

# evaluation example on Ascend
$ bash run_eval.sh [DATASET_NAME] [DATASET_PATH] [CHECKPOINT_PATH] [PLATFORM]

# standalone training example on Ascend
$ bash run_standalone_train.sh [DATASET_NAME] [DATASET_PATH] [PLATFORM]
```

For distributed training, an HCCL configuration file in JSON format needs to be created in advance.

Please follow the instructions in the link below:

[hccl_tools](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).
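For example, on a machine with 8 Ascend devices, an invocation along the following lines generates a rank table JSON in the current directory (this command line is an assumption; check the hccl_tools README for the exact flags):

```shell
python hccl_tools.py --device_num "[0,8)"
```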
## [Script Description](#contents)

### [Script and Sample Code](#contents)

```shell
crnn
├── README.md                                   # Descriptions about CRNN
├── convert_ic03.py                             # Convert the original IC03 dataset
├── convert_iiit5k.py                           # Convert the original IIIT5K dataset
├── convert_svt.py                              # Convert the original SVT dataset
├── requirements.txt                            # Requirements for this dataset
├── scripts
│   ├── run_distribute_train.sh                 # Launch distributed training on Ascend (8 pcs)
│   ├── run_eval.sh                             # Launch evaluation
│   └── run_standalone_train.sh                 # Launch standalone training (1 pc)
├── src
│   ├── config.py                               # Parameter configuration
│   ├── crnn.py                                 # CRNN network definition
│   ├── crnn_for_train.py                       # CRNN network with grad, loss and gradient clip
│   ├── dataset.py                              # Data preprocessing for training and evaluation
│   ├── ic03_dataset.py                         # Data preprocessing for IC03
│   ├── ic13_dataset.py                         # Data preprocessing for IC13
│   ├── iiit5k_dataset.py                       # Data preprocessing for IIIT5K
│   ├── loss.py                                 # CTC loss definition
│   ├── metric.py                               # Accuracy metric for the CRNN network
│   └── svt_dataset.py                          # Data preprocessing for SVT
├── eval.py                                     # Evaluation script
└── train.py                                    # Training script
```

### [Script Parameters](#contents)

#### Training Script Parameters

```shell
# distributed training on Ascend
Usage: bash run_distribute_train.sh [DATASET_NAME] [RANK_TABLE_FILE] [DATASET_PATH]

# standalone training
Usage: bash run_standalone_train.sh [DATASET_NAME] [DATASET_PATH] [PLATFORM]
```
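For instance, a concrete launch on the synthetic training set might look like the following (the rank table and dataset paths are placeholders):

```shell
bash run_distribute_train.sh synth ./hccl_8p.json /path/to/mjsynth
bash run_standalone_train.sh synth /path/to/mjsynth Ascend
```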
#### Parameters Configuration

Parameters for both training and evaluation can be set in `config.py`.

```shell
"max_text_length": 23,            # max number of characters in each image
"image_width": 100,               # width of text images
"image_height": 32,               # height of text images
"batch_size": 64,                 # batch size of input tensor
"epoch_size": 10,                 # number of epochs, only valid for training
"hidden_size": 256,               # hidden size in LSTM layers
"learning_rate": 0.02,            # initial learning rate
"momentum": 0.95,                 # momentum of SGD optimizer
"nesterov": True,                 # enable Nesterov momentum in SGD optimizer
"save_checkpoint": True,          # whether to save checkpoints
"save_checkpoint_steps": 1000,    # the step interval between two checkpoints
"keep_checkpoint_max": 30,        # only keep the last keep_checkpoint_max checkpoints
"save_checkpoint_path": "./",     # path to save checkpoints
"class_num": 37,                  # dataset class num
"input_size": 512,                # input size for LSTM layers
"num_step": 24,                   # number of time steps for LSTM layers
"use_dropout": True,              # whether to use dropout
"blank": 36                       # class index reserved for the CTC blank label
```
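As a sanity check on `class_num` and `blank`: both follow directly from the character set defined in `src/config.py`.

```python
label_dict = "abcdefghijklmnopqrstuvwxyz0123456789"  # from src/config.py
assert len(label_dict) == 36   # 26 lowercase letters + 10 digits
# class_num = 37 = len(label_dict) + 1: the extra class, index 36, is the CTC blank.
```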
### [Dataset Preparation](#contents)

- You may refer to [Dataset Prepare](#dataset-prepare) above to convert the original datasets, or you may choose to generate a text image dataset by yourself.

## [Training Process](#contents)

- Set options in `config.py`, including the learning rate and other network hyperparameters. Click the [MindSpore dataset preparation tutorial](https://www.mindspore.cn/tutorial/training/zh-CN/master/use/data_preparation.html) for more information about datasets.

### [Training](#contents)

- Run `run_standalone_train.sh` for non-distributed training of the CRNN model, either on Ascend or on GPU.

``` bash
bash run_standalone_train.sh [DATASET_NAME] [DATASET_PATH] [PLATFORM]
```

#### [Distributed Training](#contents)

- Run `run_distribute_train.sh` for distributed training of the CRNN model on Ascend.

``` bash
bash run_distribute_train.sh [DATASET_NAME] [RANK_TABLE_FILE] [DATASET_PATH]
```

Check `train_parallel0/log.txt` and you will get output like the following:

```shell
epoch: 10 step: 14110, loss is 0.0029097411
Epoch time: 2743.688s, per step time: 0.097s
```

## [Evaluation Process](#contents)

### [Evaluation](#contents)

- Run `run_eval.sh` for evaluation.

``` bash
bash run_eval.sh [DATASET_NAME] [DATASET_PATH] [CHECKPOINT_PATH] [PLATFORM]
```

Check `eval/log.txt` and you will get output like the following:

```shell
result: {'CRNNAccuracy': (0.806)}
```

## [Model Description](#contents)

### [Performance](#contents)

#### [Training Performance](#contents)

| Parameters                 | Ascend 910                                         |
| -------------------------- | -------------------------------------------------- |
| Model Version              | v1.0                                               |
| Resource                   | Ascend 910, CPU 2.60GHz 192 cores, Memory 755G     |
| Uploaded Date              | 12/15/2020 (month/day/year)                        |
| MindSpore Version          | 1.0.1                                              |
| Dataset                    | Synth                                              |
| Training Parameters        | epoch=10, steps per epoch=14110, batch_size=64     |
| Optimizer                  | SGD                                                |
| Loss Function              | CTCLoss                                            |
| outputs                    | probability                                        |
| Loss                       | 0.0029097411                                       |
| Speed                      | 118 ms/step (8 pcs)                                |
| Total time                 | 557 mins                                           |
| Parameters (M)             | 83M (.ckpt file)                                   |
| Checkpoint for Fine tuning | 20.3M (.ckpt file)                                 |
| Scripts                    | [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/crnn) |

#### [Evaluation Performance](#contents)

| Parameters          | SVT                         | IIIT5K                      |
| ------------------- | --------------------------- | --------------------------- |
| Model Version       | V1.0                        | V1.0                        |
| Resource            | Ascend 910                  | Ascend 910                  |
| Uploaded Date       | 12/15/2020 (month/day/year) | 12/15/2020 (month/day/year) |
| MindSpore Version   | 1.0.1                       | 1.0.1                       |
| Dataset             | SVT                         | IIIT5K                      |
| batch_size          | 1                           | 1                           |
| outputs             | ACC                         | ACC                         |
| Accuracy            | 80.9%                       | 80.6%                       |
| Model for inference | 83M (.ckpt file)            | 83M (.ckpt file)            |

## [Description of Random Situation](#contents)

In `dataset.py`, we set the seed inside the `create_dataset` function. We also use a random seed in `train.py` for weight initialization.

## [ModelZoo Homepage](#contents)

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
@ -0,0 +1,134 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import os
import argparse
from xml.etree import ElementTree as ET
from PIL import Image
import numpy as np


def init_args():
    parser = argparse.ArgumentParser('')
    parser.add_argument('-d', '--dataset_dir', type=str, default='./',
                        help='path to original images')
    parser.add_argument('-x', '--xml_file', type=str, default='test.xml',
                        help='path to the XML annotation file of the dataset')
    parser.add_argument('-o', '--output_dir', type=str, default='./processed',
                        help='directory to save the cropped text images')
    parser.add_argument('-a', '--output_annotation', type=str, default='./annotation.txt',
                        help='path of the output annotation file')

    return parser.parse_args()


def xml_to_dict(xml_file, save_file=False):
    """Parse the XML annotation file into a list of per-image dicts."""
    tree = ET.parse(xml_file)
    root = tree.getroot()
    imgs_labels = []

    for ch in root:
        im_label = {}
        for ch01 in ch:
            if ch01.tag == 'taggedRectangles':
                # multiple children
                rect_list = []
                for ch02 in ch01:
                    rect = {}
                    rect['location'] = ch02.attrib
                    rect['label'] = ch02[0].text
                    rect_list.append(rect)
                im_label['rect'] = rect_list
            else:
                im_label[ch01.tag] = ch01.text
        imgs_labels.append(im_label)

    if save_file:
        np.save("annotation_train.npy", imgs_labels)

    return imgs_labels


def image_crop_save(image, location, output_dir):
    """
    crop image with location (h, w, x, y)
    save cropped image to output directory
    """
    # avoid negative value of coordinates in annotation
    start_x = np.maximum(location[2], 0)
    end_x = start_x + location[1]
    start_y = np.maximum(location[3], 0)
    end_y = start_y + location[0]
    print("image array shape :{}".format(image.shape))
    print("crop region ", start_x, end_x, start_y, end_y)
    if len(image.shape) == 3:
        cropped = image[start_y:end_y, start_x:end_x, :]
    else:
        cropped = image[start_y:end_y, start_x:end_x]
    im = Image.fromarray(np.uint8(cropped))
    im.save(output_dir)


def convert():
    args = init_args()
    if not os.path.exists(args.dataset_dir):
        raise ValueError("dataset_dir :{} does not exist".format(args.dataset_dir))

    if not os.path.exists(args.xml_file):
        raise ValueError("xml_file :{} does not exist".format(args.xml_file))

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    ims_labels_dict = xml_to_dict(args.xml_file, True)
    num_images = len(ims_labels_dict)
    annotation_list = []
    print("Converting annotation, {} images in total ".format(num_images))
    for i in range(num_images):
        img_label = ims_labels_dict[i]
        image_name = img_label['imageName']
        rects = img_label['rect']
        ext = image_name.split('.')[-1]
        name = image_name[:-(len(ext) + 1)]

        fullpath = os.path.join(args.dataset_dir, image_name)
        im_array = np.asarray(Image.open(fullpath))
        print("processing image: {}".format(image_name))
        for j, rect in enumerate(rects):
            location = rect['location']
            h = int(float(location['height']))
            w = int(float(location['width']))
            x = int(float(location['x']))
            y = int(float(location['y']))
            label = rect['label']
            loc = [h, w, x, y]
            # flatten sub-directory names so every cropped image lands in output_dir
            output_name = name.replace("/", "_") + "_" + str(j) + "_" + label + '.' + ext
            output_name = output_name.replace(",", "")
            output_file = os.path.join(args.output_dir, output_name)

            image_crop_save(im_array, loc, output_file)
            ann = output_name + "," + label + ','
            annotation_list.append(ann)

    ann_file = args.output_annotation

    with open(ann_file, 'w') as f:
        for line in annotation_list:
            f.write(line + '\n')


if __name__ == "__main__":
    convert()
@ -0,0 +1,67 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import argparse
from scipy import io

###############################################
# load testdata
# testdata.mat structure
#   test[:][0] : image name
#   test[:][1] : label
#   test[:][2] : 50 lexicon
#   test[:][3] : 1000 lexicon
###############################################


def init_args():
    parser = argparse.ArgumentParser('')
    parser.add_argument('-m', '--mat_file', type=str, default='testdata.mat',
                        help='path to the MATLAB annotation file of the dataset')
    parser.add_argument('-o', '--output_dir', type=str, default='./processed',
                        help='directory to save the processed dataset')
    parser.add_argument('-a', '--output_annotation', type=str, default='./annotation.txt',
                        help='path of the output annotation file')

    return parser.parse_args()


def mat_to_list(mat_file):
    ann_ori = io.loadmat(mat_file)
    testdata = ann_ori['testdata'][0]

    ann_output = []
    for elem in testdata:
        # scipy.io.loadmat wraps each field in a nested array; unwrap to plain
        # strings (assumed structure per the header comment above)
        img_name = elem[0][0]
        label = elem[1][0]
        ann = img_name + ',' + label
        ann_output.append(ann)
    return ann_output


def convert():
    args = init_args()

    ann_list = mat_to_list(args.mat_file)

    ann_file = args.output_annotation
    with open(ann_file, 'w') as f:
        for line in ann_list:
            f.write(line + '\n')


if __name__ == "__main__":
    convert()
@ -0,0 +1,140 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

import os
import argparse
from xml.etree import ElementTree as ET
from PIL import Image
import numpy as np


def init_args():
    parser = argparse.ArgumentParser('')
    parser.add_argument('-d', '--dataset_dir', type=str, default='./',
                        help='path to original images')
    parser.add_argument('-x', '--xml_file', type=str, default='test.xml',
                        help='path to the XML annotation file of the dataset')
    parser.add_argument('-o', '--output_dir', type=str, default='./processed',
                        help='directory to save the cropped text images')

    return parser.parse_args()


def xml_to_dict(xml_file, save_file=False):
    """Parse the XML annotation file into a list of per-image dicts."""
    tree = ET.parse(xml_file)
    root = tree.getroot()
    imgs_labels = []

    for ch in root:
        im_label = {}
        for ch01 in ch:
            if ch01.tag == "address":
                continue
            if ch01.tag == 'taggedRectangles':
                # multiple children
                rect_list = []
                for ch02 in ch01:
                    rect = {}
                    rect['location'] = ch02.attrib
                    rect['label'] = ch02[0].text
                    rect_list.append(rect)
                im_label['rect'] = rect_list
            else:
                im_label[ch01.tag] = ch01.text
        imgs_labels.append(im_label)

    if save_file:
        np.save("annotation_train.npy", imgs_labels)

    return imgs_labels


def image_crop_save(image, location, output_dir):
    """
    crop image with location (h, w, x, y)
    save cropped image to output directory
    """
    # avoid negative coordinates in the annotation
    start_x = max(location[2], 0)
    end_x = start_x + location[1]
    start_y = max(location[3], 0)
    end_y = start_y + location[0]
    print("image array shape :{}".format(image.shape))
    print("crop region ", start_x, end_x, start_y, end_y)
    if len(image.shape) == 3:
        cropped = image[start_y:end_y, start_x:end_x, :]
    else:
        cropped = image[start_y:end_y, start_x:end_x]
    im = Image.fromarray(np.uint8(cropped))
    im.save(output_dir)


def convert():
    args = init_args()
    if not os.path.exists(args.dataset_dir):
        raise ValueError("dataset_dir :{} does not exist".format(args.dataset_dir))

    if not os.path.exists(args.xml_file):
        raise ValueError("xml_file :{} does not exist".format(args.xml_file))

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    ims_labels_dict = xml_to_dict(args.xml_file, True)
    num_images = len(ims_labels_dict)
    lexicon_list = []
    annotation_list = []
    print("Converting annotation, {} images in total ".format(num_images))
    for i in range(num_images):
        img_label = ims_labels_dict[i]
        image_name = img_label['imageName']
        lex = img_label['lex']
        rects = img_label['rect']
        name, ext = image_name.split('.')
        fullpath = os.path.join(args.dataset_dir, image_name)
        im_array = np.asarray(Image.open(fullpath))
        lexicon_list.append(lex)
        print("processing image: {}".format(image_name))
        for j, rect in enumerate(rects):
            location = rect['location']
            h = int(location['height'])
            w = int(location['width'])
            x = int(location['x'])
            y = int(location['y'])
            label = rect['label']
            loc = [h, w, x, y]
            # flatten sub-directory names (SVT images live under img/) so every
            # cropped image lands directly in output_dir
            output_name = name.replace("/", "_") + "_" + str(j) + "_" + label + '.' + ext
            output_file = os.path.join(args.output_dir, output_name)
            image_crop_save(im_array, loc, output_file)
            ann = output_name + "," + label + ',' + str(i)
            annotation_list.append(ann)

    lex_file = './lexicon_ann_train.txt'
    ann_file = './annotation_train.txt'
    with open(lex_file, 'w') as f:
        for line in lexicon_list:
            f.write(line + '\n')

    with open(ann_file, 'w') as f:
        for line in annotation_list:
            f.write(line + '\n')


if __name__ == "__main__":
    convert()
@ -0,0 +1,72 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CRNN evaluation."""
import os
import argparse
from mindspore import context
from mindspore.common import set_seed
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net

from src.loss import CTCLoss
from src.dataset import create_dataset
from src.crnn import CRNN
from src.metric import CRNNAccuracy

set_seed(1)

parser = argparse.ArgumentParser(description="CRNN eval")
parser.add_argument("--dataset_path", type=str, default=None, help="Dataset path, default is None.")
parser.add_argument("--checkpoint_path", type=str, default=None, help="Checkpoint file path, default is None.")
parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend', 'GPU'],
                    help='Running platform, choose from Ascend, GPU, and default is Ascend.')
parser.add_argument('--model', type=str, default='lowcase', help="Model type, default is lowcase.")
parser.add_argument('--dataset', type=str, default='synth', choices=['synth', 'ic03', 'ic13', 'svt', 'iiit5k'])
args_opt = parser.parse_args()

if args_opt.model == 'lowcase':
    from src.config import config1 as config
else:
    from src.config import config2 as config

context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, save_graphs=False)
if args_opt.platform == 'Ascend':
    device_id = int(os.getenv('DEVICE_ID'))
    context.set_context(device_id=device_id)

if __name__ == '__main__':
    config.batch_size = 1
    max_text_length = config.max_text_length
    input_size = config.input_size
    # create dataset
    dataset = create_dataset(name=args_opt.dataset,
                             dataset_path=args_opt.dataset_path,
                             batch_size=config.batch_size,
                             is_training=False,
                             config=config)
    step_size = dataset.get_dataset_size()
    loss = CTCLoss(max_sequence_length=config.num_step,
                   max_label_length=max_text_length,
                   batch_size=config.batch_size)
    net = CRNN(config)
    # load checkpoint
    param_dict = load_checkpoint(args_opt.checkpoint_path)
    load_param_into_net(net, param_dict)
    net.set_train(False)
    # define model
    model = Model(net, loss_fn=loss, metrics={'CRNNAccuracy': CRNNAccuracy(config)})
    # start evaluation
    res = model.eval(dataset, dataset_sink_mode=args_opt.platform == 'Ascend')
    print("result:", res, flush=True)
@ -0,0 +1 @@
python-Levenshtein
@ -0,0 +1,62 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 3 ]; then
    echo "Usage: bash run_distribute_train.sh [DATASET_NAME] [RANK_TABLE_FILE] [DATASET_PATH]"
    exit 1
fi

get_real_path() {
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

DATASET_NAME=$1
PATH1=$(get_real_path $2)
PATH2=$(get_real_path $3)

if [ ! -f $PATH1 ]; then
    echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
    exit 1
fi

if [ ! -d $PATH2 ]; then
    echo "error: DATASET_PATH=$PATH2 is not a directory"
    exit 1
fi

ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE=$PATH1

for ((i = 0; i < ${DEVICE_NUM}; i++)); do
    export DEVICE_ID=$i
    export RANK_ID=$i
    rm -rf ./train_parallel$i
    mkdir ./train_parallel$i
    cp ../*.py ./train_parallel$i
    cp *.sh ./train_parallel$i
    cp -r ../src ./train_parallel$i
    cd ./train_parallel$i || exit
    echo "start training for rank $RANK_ID, device $DEVICE_ID"
    env >env.log
    python train.py --platform=Ascend --dataset_path=$PATH2 --run_distribute --dataset=$DATASET_NAME > log.txt 2>&1 &
    cd ..
done
@ -0,0 +1,89 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 4 ]; then
    echo "Usage: bash run_eval.sh [DATASET_NAME] [DATASET_PATH] [CHECKPOINT_PATH] [PLATFORM]"
    exit 1
fi

get_real_path() {
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

DATASET_NAME=$1
PATH1=$(get_real_path $2)
PATH2=$(get_real_path $3)
PLATFORM=$4

if [ ! -d $PATH1 ]; then
    echo "error: DATASET_PATH=$PATH1 is not a directory"
    exit 1
fi

if [ ! -f $PATH2 ]; then
    echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
    exit 1
fi

run_ascend() {
    ulimit -u unlimited
    export DEVICE_NUM=1
    export DEVICE_ID=0
    export RANK_SIZE=$DEVICE_NUM
    export RANK_ID=0

    if [ -d "eval" ]; then
        rm -rf ./eval
    fi
    mkdir ./eval
    cp ../*.py ./eval
    cp -r ../src ./eval
    cd ./eval || exit
    env >env.log
    echo "start evaluation for device $DEVICE_ID"
    python eval.py --dataset=$DATASET_NAME --dataset_path=$1 --checkpoint_path=$2 --platform=Ascend > log.txt 2>&1 &
    cd ..
}

run_gpu() {
    if [ -d "eval" ]; then
        rm -rf ./eval
    fi
    mkdir ./eval
    cp ../*.py ./eval
    cp -r ../src ./eval
    cd ./eval || exit
    env >env.log
    python eval.py --dataset=$DATASET_NAME \
                   --dataset_path=$1 \
                   --checkpoint_path=$2 \
                   --platform=GPU > log.txt 2>&1 &
    cd ..
}

if [ "Ascend" == $PLATFORM ]; then
    run_ascend $PATH1 $PATH2
elif [ "GPU" == $PLATFORM ]; then
    run_gpu $PATH1 $PATH2
else
    echo "error: PLATFORM=$PLATFORM is not supported, only Ascend and GPU are supported."
fi
@ -0,0 +1,73 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# != 3 ]; then
    echo "Usage: bash run_standalone_train.sh [DATASET_NAME] [DATASET_PATH] [PLATFORM]"
    exit 1
fi

get_real_path() {
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

DATASET_NAME=$1
PATH1=$(get_real_path $2)
PLATFORM=$3

if [ ! -d $PATH1 ]; then
    echo "error: DATASET_PATH=$PATH1 is not a directory"
    exit 1
fi

export DEVICE_ID=0
run_ascend() {
    ulimit -u unlimited
    export DEVICE_NUM=1
    export RANK_ID=0
    export RANK_SIZE=1

    echo "start training for device $DEVICE_ID"
    env >env.log
    python train.py --dataset=$DATASET_NAME --dataset_path=$1 --platform=Ascend > log.txt 2>&1 &
    cd ..
}

run_gpu() {
    env >env.log
    python train.py --dataset=$DATASET_NAME --dataset_path=$1 --platform=GPU > log.txt 2>&1 &
    cd ..
}

WORKDIR=./train${DEVICE_ID}
if [ -d $WORKDIR ]; then
    rm -rf $WORKDIR
fi
mkdir $WORKDIR
cp ../*.py $WORKDIR
cp -r ../src $WORKDIR
cd $WORKDIR || exit

if [ "Ascend" == $PLATFORM ]; then
    run_ascend $PATH1
elif [ "GPU" == $PLATFORM ]; then
    run_gpu $PATH1
else
    echo "error: PLATFORM=$PLATFORM is not supported, only Ascend and GPU are supported."
fi
@ -0,0 +1,42 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Network parameters."""
from easydict import EasyDict


label_dict = "abcdefghijklmnopqrstuvwxyz0123456789"


# configuration for lowercase characters and digits
config1 = EasyDict({
    "max_text_length": 23,
    "image_width": 100,
    "image_height": 32,
    "batch_size": 64,
    "epoch_size": 10,
    "hidden_size": 256,
    "learning_rate": 0.02,
    "momentum": 0.95,
    "nesterov": True,
    "save_checkpoint": True,
    "save_checkpoint_steps": 1000,
    "keep_checkpoint_max": 30,
    "save_checkpoint_path": "./",
    "class_num": 37,
    "input_size": 512,
    "num_step": 24,
    "use_dropout": True,
    "blank": 36
})
@ -0,0 +1,171 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CRNN network definition."""

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.initializer import TruncatedNormal


def _bn(channel):
    return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9, gamma_init=1, beta_init=0, moving_mean_init=0,
                          moving_var_init=1)


class Conv(nn.Cell):
    def __init__(self, in_channel, out_channel, kernel_size=3, stride=1, use_bn=False, pad_mode='same'):
        super(Conv, self).__init__()
        self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride,
                              padding=0, pad_mode=pad_mode, weight_init=TruncatedNormal(0.02))
        self.bn = _bn(out_channel)
        self.Relu = nn.ReLU()
        self.use_bn = use_bn

    def construct(self, x):
        out = self.conv(x)
        if self.use_bn:
            out = self.bn(out)
        out = self.Relu(out)
        return out


class VGG(nn.Cell):
    """VGG Network structure"""
    def __init__(self, is_training=True):
        super(VGG, self).__init__()
        self.conv1 = Conv(3, 64, use_bn=True)
        self.conv2 = Conv(64, 128, use_bn=True)
        self.conv3 = Conv(128, 256, use_bn=True)
        self.conv4 = Conv(256, 256, use_bn=True)
        self.conv5 = Conv(256, 512, use_bn=True)
        self.conv6 = Conv(512, 512, use_bn=True)
        self.conv7 = Conv(512, 512, kernel_size=2, pad_mode='valid', use_bn=True)
        self.maxpool2d1 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='same')
        self.maxpool2d2 = nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1), pad_mode='same')
        self.bn1 = _bn(512)

    def construct(self, x):
        x = self.conv1(x)
        x = self.maxpool2d1(x)
        x = self.conv2(x)
        x = self.maxpool2d1(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.maxpool2d2(x)
        x = self.conv5(x)
        x = self.conv6(x)
        x = self.maxpool2d2(x)
        x = self.conv7(x)
        return x


class CRNN(nn.Cell):
    """
    Define a CRNN network, which contains a VGG backbone followed by bidirectional LSTM layers.

    Args:
        config: network configuration. The relevant fields are:
            input_size(int): size of the per-step feature fed to the LSTM; it equals the
                channel number of the VGG feature map, 512 by default.
            batch_size(int): batch size of input data, default is 64.
            hidden_size(int): hidden size in LSTM layers, default is 256.
    """
    def __init__(self, config):
        super(CRNN, self).__init__()
        self.batch_size = config.batch_size
        self.input_size = config.input_size
        self.hidden_size = config.hidden_size
        self.num_classes = config.class_num
        self.reshape = P.Reshape()
        self.cast = P.Cast()
        k = (1 / self.hidden_size) ** 0.5
        self.rnn1 = P.DynamicRNN(forget_bias=0.0)
        self.rnn1_bw = P.DynamicRNN(forget_bias=0.0)
        self.rnn2 = P.DynamicRNN(forget_bias=0.0)
        self.rnn2_bw = P.DynamicRNN(forget_bias=0.0)

        w1 = np.random.uniform(-k, k, (self.input_size + self.hidden_size, 4 * self.hidden_size))
        self.w1 = Parameter(w1.astype(np.float16), name="w1")
        w2 = np.random.uniform(-k, k, (2 * self.hidden_size + self.hidden_size, 4 * self.hidden_size))
        self.w2 = Parameter(w2.astype(np.float16), name="w2")
        w1_bw = np.random.uniform(-k, k, (self.input_size + self.hidden_size, 4 * self.hidden_size))
        self.w1_bw = Parameter(w1_bw.astype(np.float16), name="w1_bw")
        w2_bw = np.random.uniform(-k, k, (2 * self.hidden_size + self.hidden_size, 4 * self.hidden_size))
        self.w2_bw = Parameter(w2_bw.astype(np.float16), name="w2_bw")

        self.b1 = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float16), name="b1")
        self.b2 = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float16), name="b2")
        self.b1_bw = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float16), name="b1_bw")
        self.b2_bw = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float16), name="b2_bw")

        self.h1 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16))
        self.h2 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16))
        self.h1_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16))
        self.h2_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16))

        self.c1 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16))
        self.c2 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16))
        self.c1_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16))
        self.c2_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16))

        self.fc_weight = np.random.random((self.num_classes, self.hidden_size)).astype(np.float32)
        self.fc_bias = np.random.random((self.num_classes)).astype(np.float32)

        self.fc = nn.Dense(in_channels=self.hidden_size, out_channels=self.num_classes,
                           weight_init=Tensor(self.fc_weight), bias_init=Tensor(self.fc_bias))
        self.fc.to_float(mstype.float32)
        self.expand_dims = P.ExpandDims()
        self.concat = P.Concat()
        self.transpose = P.Transpose()
        self.squeeze = P.Squeeze(axis=0)
        self.vgg = VGG()
        self.reverse_seq1 = P.ReverseSequence(batch_dim=1, seq_dim=0)
        self.reverse_seq2 = P.ReverseSequence(batch_dim=1, seq_dim=0)
        self.reverse_seq3 = P.ReverseSequence(batch_dim=1, seq_dim=0)
        self.reverse_seq4 = P.ReverseSequence(batch_dim=1, seq_dim=0)
        self.seq_length = Tensor(np.ones((self.batch_size), np.int32) * config.num_step, mstype.int32)
        self.concat1 = P.Concat(axis=2)
        self.dropout = nn.Dropout(0.5)
        self.rnn_dropout = nn.Dropout(0.9)
        self.use_dropout = config.use_dropout

    def construct(self, x):
        x = self.vgg(x)
        x = self.cast(x, mstype.float16)

        # (N, C, H, W) feature map -> (num_step, N, input_size) sequence
        x = self.reshape(x, (self.batch_size, self.input_size, -1))
        x = self.transpose(x, (2, 0, 1))
        # the backward LSTM consumes the time-reversed sequence; its output is
        # reversed back before being concatenated with the forward output
        bw_x = self.reverse_seq1(x, self.seq_length)
        y1, _, _, _, _, _, _, _ = self.rnn1(x, self.w1, self.b1, None, self.h1, self.c1)
        y1_bw, _, _, _, _, _, _, _ = self.rnn1_bw(bw_x, self.w1_bw, self.b1_bw, None, self.h1_bw, self.c1_bw)
        y1_bw = self.reverse_seq2(y1_bw, self.seq_length)
        y1_out = self.concat1((y1, y1_bw))
        if self.use_dropout:
            y1_out = self.rnn_dropout(y1_out)

        y2, _, _, _, _, _, _, _ = self.rnn2(y1_out, self.w2, self.b2, None, self.h2, self.c2)
        bw_y = self.reverse_seq3(y1_out, self.seq_length)
        y2_bw, _, _, _, _, _, _, _ = self.rnn2_bw(bw_y, self.w2_bw, self.b2_bw, None, self.h2_bw, self.c2_bw)
        y2_bw = self.reverse_seq4(y2_bw, self.seq_length)
        y2_out = self.concat1((y2, y2_bw))
        if self.use_dropout:
            y2_out = self.dropout(y2_out)

        # apply the fully-connected classifier to every time step
        output = ()
        for i in range(F.shape(y2_out)[0]):
            y2_after_fc = self.fc(self.squeeze(y2[i:i+1:1]))
            y2_after_fc = self.expand_dims(y2_after_fc, 0)
            output += (y2_after_fc,)
        output = self.concat(output)
        return output
@ -0,0 +1,114 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Automatic differentiation with grad clip."""
import numpy as np
from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
                                       _get_parallel_mode)
from mindspore.context import ParallelMode
from mindspore.common import dtype as mstype
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.nn.cell import Cell
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
import mindspore.nn as nn
from mindspore.common.tensor import Tensor

compute_norm = C.MultitypeFuncGraph("compute_norm")


@compute_norm.register("Tensor")
def _compute_norm(grad):
    norm = nn.Norm()
    norm = norm(F.cast(grad, mstype.float32))
    ret = F.expand_dims(F.cast(norm, mstype.float32), 0)
    return ret


grad_div = C.MultitypeFuncGraph("grad_div")


@grad_div.register("Tensor", "Tensor")
def _grad_div(val, grad):
    div = P.RealDiv()
    mul = P.Mul()
    # rescale each gradient by 10 / clip value
    scale = div(10.0, val)
    ret = mul(grad, scale)
    return ret


class TrainOneStepCellWithGradClip(Cell):
    """
    Network training package class.

    Wraps the network with an optimizer. The resulting Cell can be trained with input data and label.
    A backward graph with gradient clipping is created in the construct function to update the parameters.
    Different parallel modes are available to run the training.

    Args:
        network (Cell): The training network.
        optimizer (Cell): Optimizer for updating the weights.
        sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.

    Inputs:
        - data (Tensor) - Tensor of shape :math:`(N, ...)`.
        - label (Tensor) - Tensor of shape :math:`(N, ...)`.

    Outputs:
        Tensor, a scalar Tensor with shape :math:`()`.
    """

    def __init__(self, network, optimizer, sens=1.0):
        super(TrainOneStepCellWithGradClip, self).__init__(auto_prefix=False)
        self.network = network
        self.network.set_grad()
        self.network.add_flags(defer_inline=True)
        self.weights = optimizer.parameters
        self.optimizer = optimizer
        self.grad = C.GradOperation(get_by_list=True, sens_param=True)
        self.sens = sens
        self.reducer_flag = False
        self.grad_reducer = None
        self.hyper_map = C.HyperMap()
        self.greater = P.Greater()
        self.select = P.Select()
        self.norm = nn.Norm(keep_dims=True)
        self.dtype = P.DType()
        self.cast = P.Cast()
        self.concat = P.Concat(axis=0)
        self.ten = Tensor(np.array([10.0]).astype(np.float32))
        parallel_mode = _get_parallel_mode()
        if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
            self.reducer_flag = True
        if self.reducer_flag:
            mean = _get_gradients_mean()
            degree = _get_device_num()
            self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)

    def construct(self, data, label):
        weights = self.weights
        loss = self.network(data, label)
        sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
        grads = self.grad(self.network, weights)(data, label, sens)
        # compute the global norm over all gradients and clip it to 10
        norm = self.hyper_map(F.partial(compute_norm), grads)
        norm = self.concat(norm)
        norm = self.norm(norm)
        cond = self.greater(norm, self.cast(self.ten, self.dtype(norm)))
        clip_val = self.select(cond, norm, self.cast(self.ten, self.dtype(norm)))
        grads = self.hyper_map(F.partial(grad_div, clip_val), grads)
        if self.reducer_flag:
            # apply grad reducer on grads
            grads = self.grad_reducer(grads)
        return F.depend(loss, self.optimizer(grads))
@ -0,0 +1,121 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dataset preprocessing."""
import os
import numpy as np
import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as vc
from PIL import Image, ImageFile
from src.config import config1, label_dict
from src.ic03_dataset import IC03Dataset
from src.ic13_dataset import IC13Dataset
from src.iiit5k_dataset import IIIT5KDataset
from src.svt_dataset import SVTDataset
ImageFile.LOAD_TRUNCATED_IMAGES = True


class CaptchaDataset:
    """
    create train or evaluation dataset for crnn

    Args:
        img_root_dir(str): root path of images
        is_training(bool): whether the dataset is used for training, default is True
        config: network configuration, default is config1
    """

    def __init__(self, img_root_dir, is_training=True, config=config1):
        if not os.path.exists(img_root_dir):
            raise RuntimeError("the input image dir {} is invalid!".format(img_root_dir))
        self.img_root_dir = img_root_dir
        if is_training:
            self.imgslist = os.path.join(self.img_root_dir, 'annotation_train.txt')
        else:
            self.imgslist = os.path.join(self.img_root_dir, 'annotation_test.txt')
        self.lexicon_file = os.path.join(self.img_root_dir, 'lexicon.txt')
        with open(self.lexicon_file, 'r') as f:
            self.lexicons = [line.strip('\n') for line in f]
        self.img_names = {}
        self.img_list = []
        with open(self.imgslist, 'r') as f:
            for line in f:
                img_name, label_index = line.strip('\n').split(" ")
                self.img_list.append(img_name)
                self.img_names[img_name] = self.lexicons[int(label_index)]
        self.max_text_length = config.max_text_length
        self.blank = config.blank
        self.class_num = config.class_num

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, item):
        img_name = self.img_list[item]
        im = Image.open(os.path.join(self.img_root_dir, img_name))
        im = im.convert("RGB")
        # swap to BGR channel order
        r, g, b = im.split()
        im = Image.merge("RGB", (b, g, r))
        image = np.array(im)
        label_str = self.img_names[img_name]
        label = []
        for c in label_str:
            if c in label_dict:
                label.append(label_dict.index(c))
        # pad the label with the blank index up to max_text_length
        label.extend([int(self.blank)] * (self.max_text_length - len(label)))
        label = np.array(label)
        return image, label


def create_dataset(name, dataset_path, batch_size=1, num_shards=1, shard_id=0, is_training=True, config=config1):
    """
    create train or evaluation dataset for crnn

    Args:
        name(str): dataset name, one of 'synth', 'ic03', 'ic13', 'svt', 'iiit5k'
        dataset_path(str): dataset path
        batch_size(int): batch size of the generated dataset, default is 1
        num_shards(int): number of devices
        shard_id(int): rank id
        is_training(bool): whether the dataset is used for training, default is True
        config: network configuration, default is config1
    """
    if name == 'synth':
        dataset = CaptchaDataset(dataset_path, is_training, config)
    elif name == 'ic03':
        dataset = IC03Dataset(dataset_path, "annotation.txt", config, True, 3)
    elif name == 'ic13':
        dataset = IC13Dataset(dataset_path, "Challenge2_Test_Task3_GT.txt", config)
    elif name == 'svt':
        dataset = SVTDataset(dataset_path, config)
    elif name == 'iiit5k':
        dataset = IIIT5KDataset(dataset_path, "annotation.txt", config)
    else:
        raise ValueError(f"unsupported dataset name: {name}")
    ds = de.GeneratorDataset(dataset, ["image", "label"], shuffle=True, num_shards=num_shards, shard_id=shard_id)
    image_trans = [
        vc.Resize((config.image_height, config.image_width)),
        vc.Normalize([127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5]),
        vc.HWC2CHW()
    ]
    label_trans = [
        C.TypeCast(mstype.int32)
    ]
    ds = ds.map(operations=image_trans, input_columns=["image"], num_parallel_workers=8)
    ds = ds.map(operations=label_trans, input_columns=["label"], num_parallel_workers=8)

    ds = ds.batch(batch_size, drop_remainder=True)
    return ds
@ -0,0 +1,80 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
"""Dataset adaptor for SVT"""
|
||||
import os
|
||||
import numpy as np
|
||||
from PIL import Image, ImageFile
|
||||
from src.config import config1, label_dict
|
||||
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
||||
|
||||
|
||||
class IC03Dataset:
|
||||
"""
|
||||
create train or evaluation dataset for crnn
|
||||
|
||||
Args:
|
||||
img_root_dir(str): root path of images
|
||||
max_text_length(int): max number of digits in images.
|
||||
device_target(str): platform of training, support Ascend and GPU.
|
||||
"""
|
||||
|
||||
def __init__(self, img_root_dir, anno_file="annotation.txt", config=config1, filter_by_dict=True, filter_length=3):
|
||||
if not os.path.exists(img_root_dir):
|
||||
raise RuntimeError("the input image dir {} is invalid!".format(img_root_dir))
|
||||
self.img_root_dir = img_root_dir
|
||||
anno_file = os.path.join(img_root_dir, anno_file)
|
||||
|
||||
self.img_names = {}
|
||||
self.img_list = []
|
||||
with open(anno_file, 'r') as f:
|
||||
for lines in f:
|
||||
img_name = lines.split(",")[0]
|
||||
label = lines.split(",")[1].lower()
|
||||
if len(label) < filter_length:
|
||||
continue
|
||||
if filter_by_dict:
|
||||
flag = True
|
||||
for c in label:
|
||||
if c not in label_dict:
|
||||
flag = False
|
||||
break
|
||||
if not flag:
|
||||
continue
|
||||
self.img_names[img_name] = label
|
||||
self.img_list.append(img_name)
|
||||
|
||||
self.max_text_length = config.max_text_length
|
||||
self.blank = config.blank
|
||||
self.class_num = config.class_num
|
||||
|
||||
def __len__(self):
|
||||
return len(self.img_names)
|
||||
|
||||
def __getitem__(self, item):
|
||||
img_name = self.img_list[item]
|
||||
im = Image.open(os.path.join(self.img_root_dir, img_name))
|
||||
im = im.convert("RGB")
|
||||
r, g, b = im.split()
|
||||
im = Image.merge("RGB", (b, g, r))
|
||||
image = np.array(im)
|
||||
label_str = self.img_names[img_name]
|
||||
label = []
|
||||
for c in label_str:
|
||||
if c in label_dict:
|
||||
label.append(label_dict.index(c))
|
||||
label.extend([int(self.blank)] * (self.max_text_length - len(label)))
|
||||
label = np.array(label)
|
||||
return image, label
|

@ -0,0 +1,77 @@

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Dataset adaptor for IC13"""
import os
import numpy as np
from PIL import Image, ImageFile
from src.config import config1, label_dict
ImageFile.LOAD_TRUNCATED_IMAGES = True


class IC13Dataset:
    """
    Create an evaluation dataset for CRNN on IC13.

    Args:
        img_root_dir(str): root path of images
        label_file(str): name of the ground-truth file under img_root_dir
        config: model configuration, default is config1
        filter_by_dict(bool): drop samples containing characters outside label_dict
        filter_length(int): drop samples whose label is shorter than this length
    """

    def __init__(self, img_root_dir, label_file="", config=config1, filter_by_dict=True, filter_length=3):
        if not os.path.exists(img_root_dir):
            raise RuntimeError("the input image dir {} is invalid".format(img_root_dir))
        self.img_root_dir = img_root_dir
        self.label_file = os.path.join(img_root_dir, label_file)
        self.img_names = {}
        self.img_list = []
        self.config = config
        with open(self.label_file, 'r') as f:
            for line in f:
                img_name = line.split(",")[0]
                label = line.split("\"")[1].lower()  # the label is quoted in the GT file
                if len(label) < filter_length:
                    continue
                if filter_by_dict:
                    flag = True
                    for c in label:
                        if c not in label_dict:
                            flag = False
                            break
                    if not flag:
                        continue
                self.img_names[img_name] = label
                self.img_list.append(img_name)

        self.max_text_length = config.max_text_length
        self.blank = config.blank
        self.class_num = config.class_num

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, item):
        img_name = self.img_list[item]
        im = Image.open(os.path.join(self.img_root_dir, img_name))
        im = im.convert("RGB")
        r, g, b = im.split()
        im = Image.merge("RGB", (b, g, r))  # reorder channels to BGR
        image = np.array(im)
        label_str = self.img_names[img_name]
        label = []
        for c in label_str:
            if c in label_dict:
                label.append(label_dict.index(c))
        label.extend([int(self.blank)] * (self.max_text_length - len(label)))
        label = np.array(label)
        return image, label

@ -0,0 +1,69 @@

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Dataset adaptor for IIIT5K"""
import os
import numpy as np
from PIL import Image, ImageFile
from src.config import config1, label_dict
ImageFile.LOAD_TRUNCATED_IMAGES = True


class IIIT5KDataset:
    """
    Create a train or evaluation dataset for CRNN on IIIT5K.

    Args:
        img_root_dir(str): root path of images
        anno_file(str): name of the annotation file under img_root_dir
        config: model configuration, default is config1
    """

    def __init__(self, img_root_dir, anno_file="annotation.txt", config=config1):
        if not os.path.exists(img_root_dir):
            raise RuntimeError("the input image dir {} is invalid!".format(img_root_dir))
        self.img_root_dir = img_root_dir
        anno_file = os.path.join(img_root_dir, anno_file)

        self.img_names = {}
        self.img_list = []
        with open(anno_file, 'r') as f:
            for line in f:
                img_name = line.split(",")[0]
                label = line.split(",")[1].strip().lower()  # strip the trailing newline
                self.img_names[img_name] = label
                self.img_list.append(img_name)

        self.max_text_length = config.max_text_length
        self.blank = config.blank
        self.class_num = config.class_num

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, item):
        img_name = self.img_list[item]
        im = Image.open(os.path.join(self.img_root_dir, img_name))
        im = im.convert("RGB")
        r, g, b = im.split()
        im = Image.merge("RGB", (b, g, r))  # reorder channels to BGR
        image = np.array(im)
        label_str = self.img_names[img_name]
        label = []
        for c in label_str:
            if c in label_dict:
                label.append(label_dict.index(c))
        label.extend([int(self.blank)] * (self.max_text_length - len(label)))
        label = np.array(label)
        return image, label

@ -0,0 +1,49 @@

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CTC Loss."""
import numpy as np
from mindspore.nn.loss.loss import _Loss
from mindspore import Tensor, Parameter
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P


class CTCLoss(_Loss):
    """
    CTCLoss definition

    Args:
        max_sequence_length(int): max sequence length. For text images, the value is
            equal to the image width.
        max_label_length(int): max label length for each input.
        batch_size(int): batch size of input logits.
    """

    def __init__(self, max_sequence_length, max_label_length, batch_size):
        super(CTCLoss, self).__init__()
        # every sample in the batch shares the same (maximal) sequence length
        self.sequence_length = Parameter(Tensor(np.array([max_sequence_length] * batch_size), mstype.int32),
                                         name="sequence_length")
        # (batch, position) indices pairing each padded label entry with its
        # place in the flattened label values vector
        labels_indices = []
        for i in range(batch_size):
            for j in range(max_label_length):
                labels_indices.append([i, j])
        self.labels_indices = Parameter(Tensor(np.array(labels_indices), mstype.int64), name="labels_indices")
        self.reshape = P.Reshape()
        self.ctc_loss = P.CTCLoss(ctc_merge_repeated=True)

    def construct(self, logit, label):
        labels_values = self.reshape(label, (-1,))
        loss, _ = self.ctc_loss(logit, self.labels_indices, labels_values, self.sequence_length)
        return loss
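

if __name__ == "__main__":
    # Worked sketch of the sparse-label layout built in __init__, using assumed
    # toy values batch_size=2, max_label_length=3 and blank index 36.
    batch_size, max_label_length, blank = 2, 3, 36
    labels_indices = np.array([[i, j] for i in range(batch_size) for j in range(max_label_length)])
    label = np.array([[2, 0, blank], [1, blank, blank]])  # two padded labels
    labels_values = label.reshape(-1)
    # each (batch, position) row in labels_indices addresses one entry of labels_values
    assert labels_values.tolist() == [2, 0, blank, 1, blank, blank]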

@ -0,0 +1,94 @@

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Metric for accuracy evaluation."""
from mindspore import nn
import Levenshtein
label_dict = "abcdefghijklmnopqrstuvwxyz0123456789"


class CRNNAccuracy(nn.Metric):
    """
    Define the sequence accuracy metric for the CRNN network.
    """

    def __init__(self, config):
        super(CRNNAccuracy, self).__init__()
        self.config = config
        self._correct_num = 0
        self._total_num = 0
        self.blank = config.blank

    def clear(self):
        self._correct_num = 0
        self._total_num = 0

    def update(self, *inputs):
        if len(inputs) != 2:
            raise ValueError('CRNNAccuracy need 2 inputs (y_pred, y), but got {}'.format(len(inputs)))
        y_pred = self._convert_data(inputs[0])
        y = self._convert_data(inputs[1])
        str_pred = self._ctc_greedy_decoder(y_pred)
        str_label = self._convert_labels(y)

        for pred, label in zip(str_pred, str_label):
            print(pred, " :: ", label)  # log each prediction against its ground truth
            edit_distance = Levenshtein.distance(pred, label)
            self._total_num += 1
            if edit_distance == 0:
                self._correct_num += 1

    def eval(self):
        if self._total_num == 0:
            raise RuntimeError('Accuracy can not be calculated, because the number of samples is 0.')
        print('correct num: ', self._correct_num, ', total num: ', self._total_num)
        sequence_accuracy = self._correct_num / self._total_num
        return sequence_accuracy

    def _arr2char(self, inputs):
        string = ""
        for i in inputs:
            if i < self.blank:
                string += label_dict[i]
        return string

    def _convert_labels(self, inputs):
        str_list = []
        for label in inputs:
            str_temp = self._arr2char(label)
            str_list.append(str_temp)
        return str_list

    def _ctc_greedy_decoder(self, y_pred):
        """
        Parse the network prediction into label strings.
        """
        seq_len, batch_size, _ = y_pred.shape
        indices = y_pred.argmax(axis=2)
        lens = [seq_len] * batch_size
        pred_labels = []
        for i in range(batch_size):
            idx = indices[:, i]
            last_idx = self.blank
            pred_label = []
            for j in range(lens[i]):
                cur_idx = idx[j]
                # CTC collapse: keep a class only when it differs from the
                # previous step and is not the blank symbol
                if cur_idx not in [last_idx, self.blank]:
                    pred_label.append(cur_idx)
                last_idx = cur_idx
            pred_labels.append(pred_label)
        str_results = []
        for i in pred_labels:
            str_results.append(self._arr2char(i))
        return str_results
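

if __name__ == "__main__":
    # Toy sketch of the greedy CTC collapse in _ctc_greedy_decoder, with the
    # blank index assumed to be 36 (len(label_dict)); values are illustrative.
    blank = 36
    steps = [1, 1, 36, 1, 2, 2]  # per-step argmax class ids for one sequence
    decoded, last = [], blank
    for cur in steps:
        if cur not in (last, blank):
            decoded.append(cur)
        last = cur
    # repeats collapse, blanks separate genuine repeats: [1, 1, 2] -> "bbc"
    assert "".join(label_dict[i] for i in decoded) == "bbc"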

@ -0,0 +1,67 @@

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Dataset adaptor for SVT"""
import os
import numpy as np
from PIL import Image, ImageFile
from src.config import config1, label_dict
ImageFile.LOAD_TRUNCATED_IMAGES = True


class SVTDataset:
    """
    Create a train or evaluation dataset for CRNN on SVT.

    Args:
        img_root_dir(str): root path of images; the ground-truth text is
            recovered from each file name.
        config: model configuration, default is config1
    """

    def __init__(self, img_root_dir, config=config1):
        if not os.path.exists(img_root_dir):
            raise RuntimeError("the input image dir {} is invalid!".format(img_root_dir))
        self.img_root_dir = img_root_dir
        file_list = os.listdir(img_root_dir)
        self.img_names = {}
        self.img_list = []
        for f in file_list:
            label = f.split(".jpg")[0]
            label = label.split("_")[-1].lower()
            self.img_names[f] = label
            self.img_list.append(f)

        self.max_text_length = config.max_text_length
        self.blank = config.blank
        self.class_num = config.class_num

    def __len__(self):
        return len(self.img_names)

    def __getitem__(self, item):
        img_name = self.img_list[item]
        im = Image.open(os.path.join(self.img_root_dir, img_name))
        im = im.convert("RGB")
        r, g, b = im.split()
        im = Image.merge("RGB", (b, g, r))  # reorder channels to BGR
        image = np.array(im)
        label_str = self.img_names[img_name]
        label = []
        for c in label_str:
            if c in label_dict:
                label.append(label_dict.index(c))
        label.extend([int(self.blank)] * (self.max_text_length - len(label)))
        label = np.array(label)
        return image, label
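

if __name__ == "__main__":
    # Sketch of the file-name parsing above, assuming cropped word images
    # named like "<prefix>_<WORD>.jpg"; the example name is illustrative.
    fname = "img_0012_HELLO.jpg"
    demo_label = fname.split(".jpg")[0].split("_")[-1].lower()
    assert demo_label == "hello"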

@ -0,0 +1,101 @@

# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""crnn training"""
import os
import argparse
import mindspore.nn as nn
from mindspore import context
from mindspore.common import set_seed
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.nn.wrap import WithLossCell
from mindspore.train.callback import TimeMonitor, LossMonitor, CheckpointConfig, ModelCheckpoint
from mindspore.communication.management import init, get_group_size, get_rank

from src.loss import CTCLoss
from src.dataset import create_dataset
from src.crnn import CRNN
from src.crnn_for_train import TrainOneStepCellWithGradClip

set_seed(1)

parser = argparse.ArgumentParser(description="crnn training")
parser.add_argument("--run_distribute", action='store_true', help="Run distributed training; default is false.")
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path, default is None')
parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend', 'GPU'],
                    help='Running platform, one of Ascend or GPU; default is Ascend.')
parser.add_argument('--model', type=str, default='lowercase', help="Model type, default is lowercase")
parser.add_argument('--dataset', type=str, default='synth', choices=['synth', 'ic03', 'ic13', 'svt', 'iiit5k'])
parser.set_defaults(run_distribute=False)
args_opt = parser.parse_args()

if args_opt.model == 'lowercase':
    from src.config import config1 as config
else:
    from src.config import config2 as config
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, save_graphs=False)
if args_opt.platform == 'Ascend':
    device_id = int(os.getenv('DEVICE_ID'))
    context.set_context(device_id=device_id)


if __name__ == '__main__':
    lr_scale = 1
    if args_opt.run_distribute:
        init()
        if args_opt.platform == 'Ascend':
            device_num = int(os.environ.get("RANK_SIZE"))
            rank = int(os.environ.get("RANK_ID"))
        else:
            device_num = get_group_size()
            rank = get_rank()
        context.reset_auto_parallel_context()
        context.set_auto_parallel_context(device_num=device_num,
                                          parallel_mode=ParallelMode.DATA_PARALLEL,
                                          gradients_mean=True)
    else:
        device_num = 1
        rank = 0

    max_text_length = config.max_text_length
    # create dataset
    dataset = create_dataset(name=args_opt.dataset, dataset_path=args_opt.dataset_path, batch_size=config.batch_size,
                             num_shards=device_num, shard_id=rank, config=config)
    step_size = dataset.get_dataset_size()
    # define lr: cosine decay from config.learning_rate down to 0 over all epochs
    lr_init = config.learning_rate
    lr = nn.dynamic_lr.cosine_decay_lr(0.0, lr_init, config.epoch_size * step_size, step_size, config.epoch_size)
    loss = CTCLoss(max_sequence_length=config.num_step,
                   max_label_length=max_text_length,
                   batch_size=config.batch_size)
    net = CRNN(config)
    opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum, nesterov=config.nesterov)

    net = WithLossCell(net, loss)
    net = TrainOneStepCellWithGradClip(net, opt).set_train()
    # define model
    model = Model(net)
    # define callbacks
    callbacks = [LossMonitor(), TimeMonitor(data_size=step_size)]
    if config.save_checkpoint:
        config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps,
                                     keep_checkpoint_max=config.keep_checkpoint_max)
        save_ckpt_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(rank) + '/')
        ckpt_cb = ModelCheckpoint(prefix="crnn", directory=save_ckpt_path, config=config_ck)
        callbacks.append(ckpt_cb)
    model.train(config.epoch_size, dataset, callbacks=callbacks)
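
# A quick sketch (illustrative numbers, not the values from config) of how
# nn.dynamic_lr.cosine_decay_lr expands the schedule above into one learning
# rate per training step:
#
#     demo_lr = nn.dynamic_lr.cosine_decay_lr(0.0, 0.02, 10 * 100, 100, 10)
#     len(demo_lr)                # 1000 entries, one per step
#     demo_lr[0], demo_lr[-1]     # starts at max_lr (0.02), decays toward min_lr (0.0)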

@ -4,11 +4,11 @@

- [MobileNetV1 Description](#mobilenetv1-description)
- [Model architecture](#model-architecture)
- [Dataset](#dataset)
- [[Features]](#features)
- [[Mixed Precision(Ascend)]](#mixed-precisionascend)
- [[Environment Requirements]](#environment-requirements)
- [[Script description]](#script-description)
- [[Script and sample code]](#script-and-sample-code)
- [Features](#features)
- [Mixed Precision(Ascend)](#mixed-precisionascend)
- [Environment Requirements](#environment-requirements)
- [Script description](#script-description)
- [Script and sample code](#script-and-sample-code)
- [Training process](#training-process)
- [Usage](#usage)
- [Launch](#launch)

@ -17,7 +17,7 @@

- [Usage](#usage-1)
- [Launch](#launch-1)
- [Result](#result-1)
- [[Model description]](#model-description)
- [Model description](#model-description)
- [Performance](#performance)
- [Training Performance](#training-performance)
- [Description of Random Situation](#description-of-random-situation)

@ -37,6 +37,8 @@ The overall network architecture of MobileNetV1 is show below:

## [Dataset](#contents)

Note that you can run the scripts based on the dataset mentioned in original paper or widely used in relevant domain/network architecture. In the following sections, we will introduce how to run the scripts using the related dataset below.

Dataset used: [ImageNet2012](http://www.image-net.org/)

- Dataset size 224*224 colorful images in 1000 classes