This commit is contained in:
chenhaozhe 2020-12-12 10:39:38 +08:00
parent c70735c19f
commit db674ee7ab
22 changed files with 1876 additions and 7 deletions

View File

@ -17,7 +17,7 @@ Neural Networks Cells.
Pre-defined building blocks or computing units to construct Neural Networks.
"""
from . import layer, loss, optim, metrics, wrap, probability, sparse
from . import layer, loss, optim, metrics, wrap, probability, sparse, dynamic_lr
from .learning_rate_schedule import *
from .dynamic_lr import *
from .cell import Cell, GraphKernel

View File

@ -0,0 +1,244 @@
# Contents
- [Contents](#contents)
- [CRNN Description](#crnn-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Dataset Prepare](#dataset-prepare)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters)
- [Training Script Parameters](#training-script-parameters)
- [Parameters Configuration](#parameters-configuration)
- [Dataset Preparation](#dataset-preparation)
- [Training Process](#training-process)
- [Training](#training)
- [Distributed Training](#distributed-training)
- [Evaluation Process](#evaluation-process)
- [Evaluation](#evaluation)
- [Model Description](#model-description)
- [Performance](#performance)
- [Training Performance](#training-performance)
- [Evaluation Performance](#evaluation-performance)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
## [CRNN Description](#contents)
CRNN is a neural network for image-based sequence recognition and its application to scene text recognition. The paper investigates the problem of scene text recognition, which is among the most important and challenging tasks in image-based sequence recognition. A novel neural network architecture, which integrates feature extraction, sequence modeling and transcription into a unified framework, is proposed. Compared with previous systems for scene text recognition, the proposed architecture possesses four distinctive properties: (1) It is end-to-end trainable, in contrast to most of the existing algorithms whose components are separately trained and tuned. (2) It naturally handles sequences of arbitrary length, involving no character segmentation or horizontal scale normalization. (3) It is not confined to any predefined lexicon and achieves remarkable performance in both lexicon-free and lexicon-based scene text recognition tasks. (4) It generates an effective yet much smaller model, which is more practical for real-world application scenarios.
[Paper](https://arxiv.org/abs/1507.05717): Baoguang Shi, Xiang Bai, Cong Yao, "An End-to-End Trainable Neural Network for Image-based Sequence Recognition and Its Application to Scene Text Recognition", ArXiv, vol. abs/1507.05717, 2015.
## [Model Architecture](#content)
CRNN uses a VGG16-style backbone for feature extraction, followed by a two-layer bidirectional LSTM, and finally a CTC loss for transcription. See src/crnn.py for details.
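The overall shape flow, using the values in `src/config.py`, can be sketched as follows (an illustrative walk-through of the tensor shapes only, not the actual network code):
```python
# Shape walk-through of the CRNN pipeline using the default configuration
# (image 32x100, input_size=512, num_step=24, hidden_size=256, class_num=37).
# This is an illustrative sketch only, not the real MindSpore network.
import numpy as np

batch_size = 64
images = np.zeros((batch_size, 3, 32, 100), dtype=np.float32)   # input text images (NCHW)

# The VGG-style backbone reduces each image to a (batch, 512, 1, 24) feature map;
# src/crnn.py reshapes it to (batch, 512, 24) and transposes it to (num_step, batch, input_size).
feature_sequence = np.zeros((24, batch_size, 512), dtype=np.float32)

# Two bidirectional LSTM layers (256 hidden units per direction) keep the sequence length,
# and a per-step fully connected layer maps every step to 37 classes (36 characters + blank),
# which is exactly what CTCLoss consumes.
logits = np.zeros((24, batch_size, 37), dtype=np.float32)
print(images.shape, feature_sequence.shape, logits.shape)
```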
## [Dataset](#content)
Note that you can run the scripts based on the dataset mentioned in the original paper or one widely used in the relevant domain/network architecture. In the following sections, we will introduce how to run the scripts using the related datasets below.
We use five datasets mentioned in the paper. For training, we use the synthetic dataset ([MJSynth](https://www.robots.ox.ac.uk/~vgg/data/text/) and [SynthText](https://github.com/ankush-me/SynthText)) released by Jaderberg et al. as the training data, which contains 8 million training images and their corresponding ground-truth words. For evaluation, we use four popular benchmarks for scene text recognition, namely ICDAR 2003 ([IC03](http://www.iapr-tc11.org/mediawiki/index.php?title=ICDAR_2003_Robust_Reading_Competitions)), ICDAR 2013 ([IC13](https://rrc.cvc.uab.es/?ch=2&com=downloads)), IIIT 5k-word ([IIIT5K](https://cvit.iiit.ac.in/research/projects/cvit-projects/the-iiit-5k-word-dataset)), and Street View Text ([SVT](http://vision.ucsd.edu/~kai/grocr/)).
### [Dataset Prepare](#content)
For the `IC03`, `IIIT5K` and `SVT` datasets, the original data from the official websites cannot be used directly with CRNN.
- `IC03`: the text regions need to be cropped from the original images according to `words.xml`.
- `IIIT5K`: the annotations need to be extracted from the MATLAB data file.
- `SVT`: the text regions need to be cropped from the original images according to `train.xml` or `test.xml`.
We provide `convert_ic03.py`, `convert_iiit5k.py` and `convert_svt.py` as reference examples for the preprocessing above; a sketch of how the resulting annotation file is consumed is shown below.
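Each convert script writes one annotation line per cropped image in the form `<cropped_image_name>,<label>,...`. Below is a minimal sketch of how the dataset adapters read such a file back (it mirrors the parsing in `src/ic03_dataset.py`; the file path is only an assumption for illustration):
```python
# Read an annotation file produced by the convert scripts and print the samples.
# The path "./processed/annotation.txt" is an assumed example location.
with open("./processed/annotation.txt", "r") as f:
    for line in f:
        fields = line.strip("\n").split(",")
        img_name, label = fields[0], fields[1].lower()   # same fields used by the dataset adapters
        print(img_name, label)
```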
## [Environment Requirements](#contents)
- Hardware (Ascend)
- Prepare hardware environment with Ascend processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. You will be able to have access to related resources once approved.
- Framework
- [MindSpore](https://gitee.com/mindspore/mindspore)
- For more information, please check the resources below
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
## [Quick Start](#contents)
- After the dataset is prepared, you may start running the training or the evaluation scripts as follows:
- Running on Ascend
```shell
# distribute training example in Ascend
$ bash run_distribute_train.sh [DATASET_NAME] [RANK_TABLE_FILE] [DATASET_PATH]
# evaluation example in Ascend
$ bash run_eval.sh [DATASET_NAME] [DATASET_PATH] [CHECKPOINT_PATH] [PLATFORM]
# standalone training example in Ascend
$ bash run_standalone_train.sh [DATASET_NAME] [DATASET_PATH] [PLATFORM]
```
For distributed training, an hccl configuration file in JSON format needs to be created in advance.
Please follow the instructions in the link below:
[hccl_tools](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).
## [Script Description](#contents)
### [Script and Sample Code](#contents)
```shell
crnn
├── README.md # Descriptions about CRNN
├── convert_ic03.py # Convert the original IC03 dataset
├── convert_iiit5k.py # Convert the original IIIT5K dataset
├── convert_svt.py # Convert the original SVT dataset
├── requirements.txt # Requirements for this dataset
├── scripts
│   ├── run_distribute_train.sh # Launch distributed training in Ascend(8 pcs)
│   ├── run_eval.sh # Launch evaluation
│   └── run_standalone_train.sh # Launch standalone training(1 pcs)
├── src
│   ├── config.py # Parameter configuration
│   ├── crnn.py # crnn network definition
│   ├── crnn_for_train.py # crnn network with grad, loss and gradient clip
│   ├── dataset.py # Data preprocessing for training and evaluation
│   ├── ic03_dataset.py # Data preprocessing for IC03
│   ├── ic13_dataset.py # Data preprocessing for IC13
│   ├── iiit5k_dataset.py # Data preprocessing for IIIT5K
│   ├── loss.py # Ctcloss definition
│   ├── metric.py # accuracy metric for crnn network
│   └── svt_dataset.py # Data preprocessing for SVT
├── eval.py # Evaluation script
└── train.py # Training script
```
### [Script Parameters](#contents)
#### Training Script Parameters
```shell
# distributed training in Ascend
Usage: bash run_distribute_train.sh [DATASET_NAME] [RANK_TABLE_FILE] [DATASET_PATH]
# standalone training
Usage: bash run_standalone_train.sh [DATASET_NAME] [DATASET_PATH] [PLATFORM]
```
#### Parameters Configuration
Parameters for both training and evaluation can be set in config.py.
```shell
max_text_length": 23, # max number of digits in each
"image_width": 100, # width of text images
"image_height": 32, # height of text images
"batch_size": 64, # batch size of input tensor
"epoch_size": 10, # only valid for taining, which is always 1
"hidden_size": 256, # hidden size in LSTM layers
"learning_rate": 0.02, # initial learning rate
"momentum": 0.95, # momentum of SGD optimizer
"nesterov": True, # enable nesterov in SGD optimizer
"save_checkpoint": True, # whether save checkpoint or not
"save_checkpoint_steps": 1000, # the step interval between two checkpoints.
"keep_checkpoint_max": 30, # only keep the last keep_checkpoint_max
"save_checkpoint_path": "./", # path to save checkpoint
"class_num": 37, # dataset class num
"input_size": 512, # input size for LSTM layer
"num_step": 24, # num step for LSTM layer
"use_dropout": True, # whether use dropout
"blank": 36 # add blank for classification
```
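The configuration is a plain `EasyDict` defined in `src/config.py`; below is a minimal sketch of reading or overriding it (as `eval.py` does when it forces `batch_size` to 1):
```python
# Minimal example of using the configuration object from src/config.py.
# Run it from the crnn directory so that the src package is importable.
from src.config import config1 as config

config.batch_size = 1                       # eval.py overrides this for evaluation
print(config.learning_rate, config.class_num, config.num_step)
```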
### [Dataset Preparation](#contents)
- Refer to [Dataset Prepare](#dataset-prepare) above for converting the original datasets into the expected annotation format, or you may choose to generate a text image dataset by yourself.
## [Training Process](#contents)
- Set options in `config.py`, including the learning rate and other network hyperparameters. See the [MindSpore dataset preparation tutorial](https://www.mindspore.cn/tutorial/training/zh-CN/master/use/data_preparation.html) for more information about datasets.
### [Training](#contents)
- Run `run_standalone_train.sh` for non-distributed training of the CRNN model, either on Ascend or on GPU.
``` bash
bash run_standalone_train.sh [DATASET_NAME] [DATASET_PATH] [PLATFORM]
```
#### [Distributed Training](#contents)
- Run `run_distribute_train.sh` for distributed training of the CRNN model on Ascend.
``` bash
bash run_distribute_train.sh [DATASET_NAME] [RANK_TABLE_FILE] [DATASET_PATH]
```
Check `train_parallel0/log.txt` and you will get output like the following:
```shell
epoch: 10 step: 14110, loss is 0.0029097411
Epoch time: 2743.688s, per step time: 0.097s
```
## [Evaluation Process](#contents)
### [Evaluation](#contents)
- Run `run_eval.sh` for evaluation.
``` bash
bash run_eval.sh [DATASET_NAME] [DATASET_PATH] [CHECKPOINT_PATH] [PLATFORM]
```
Check `eval/log.txt` and you will get output like the following:
```shell
result: {'CRNNAccuracy': (0.806)}
```
## [Model Description](#contents)
### [Performance](#contents)
#### [Training Performance](#contents)
| Parameters | Ascend 910 |
| -------------------------- | --------------------------------------------------|
| Model Version | v1.0 |
| Resource | Ascend 910, CPU 2.60GHz 192cores, Memory 755G |
| Uploaded Date | 12/15/2020 (month/day/year) |
| MindSpore Version | 1.0.1 |
| Dataset | Synth |
| Training Parameters | epoch=10, steps per epoch=14110, batch_size = 64 |
| Optimizer | SGD |
| Loss Function | CTCLoss |
| outputs | probability |
| Loss | 0.0029097411 |
| Speed | 118ms/step(8pcs) |
| Total time | 557 mins |
| Parameters (M) | 83M (.ckpt file) |
| Checkpoint for Fine tuning | 20.3M (.ckpt file) |
| Scripts | [Link](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/crnn) |
#### [Evaluation Performance](#contents)
| Parameters | SVT | IIIT5K |
| ------------------- | --------------------------- | --------------------------- |
| Model Version | V1.0 | V1.0 |
| Resource | Ascend 910 | Ascend 910 |
| Uploaded Date | 12/15/2020 (month/day/year) | 12/15/2020 (month/day/year) |
| MindSpore Version | 1.0.1 | 1.0.1 |
| Dataset | SVT | IIIT5K |
| batch_size | 1 | 1 |
| outputs | ACC | ACC |
| Accuracy | 80.9% | 80.6% |
| Model for inference | 83M (.ckpt file) | 83M (.ckpt file) |
## [Description of Random Situation](#contents)
In `dataset.py`, we set the seed inside the `create_dataset` function. We also use a random seed in `train.py` for weight initialization.
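For reference, `eval.py` fixes the global seed as shown below; a minimal sketch of the same call that can also be placed at the top of a training script:
```python
# Fix the MindSpore global seed for reproducible weight initialization (mirrors eval.py).
from mindspore.common import set_seed

set_seed(1)
```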
## [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

View File

@ -0,0 +1,134 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import argparse
from xml.etree import ElementTree as ET
from PIL import Image
import numpy as np
def init_args():
parser = argparse.ArgumentParser('')
parser.add_argument('-d', '--dataset_dir', type=str, default='./',
help='path to original images')
parser.add_argument('-x', '--xml_file', type=str, default='test.xml',
help='path to the xml annotation file, e.g. words.xml')
parser.add_argument('-o', '--output_dir', type=str, default='./processed',
help='directory to save the cropped images')
parser.add_argument('-a', '--output_annotation', type=str, default='./annotation.txt',
help='path of the generated annotation file')
return parser.parse_args()
def xml_to_dict(xml_file, save_file=False):
tree = ET.parse(xml_file)
root = tree.getroot()
imgs_labels = []
for ch in root:
im_label = {}
for ch01 in ch:
if ch01.tag in 'taggedRectangles':
# multiple children
rect_list = []
for ch02 in ch01:
rect = {}
rect['location'] = ch02.attrib
rect['label'] = ch02[0].text
rect_list.append(rect)
im_label['rect'] = rect_list
else:
im_label[ch01.tag] = ch01.text
imgs_labels.append(im_label)
if save_file:
np.save("annotation_train.npy", imgs_labels)
return imgs_labels
def image_crop_save(image, location, output_dir):
"""
crop image with location (h,w,x,y)
save cropped image to output directory
"""
# avoid negative value of coordinates in annotation
start_x = np.maximum(location[2], 0)
end_x = start_x + location[1]
start_y = np.maximum(location[3], 0)
end_y = start_y + location[0]
print("image array shape :{}".format(image.shape))
print("crop region ", start_x, end_x, start_y, end_y)
if len(image.shape) == 3:
cropped = image[start_y:end_y, start_x:end_x, :]
else:
cropped = image[start_y:end_y, start_x:end_x]
im = Image.fromarray(np.uint8(cropped))
im.save(output_dir)
def convert():
args = init_args()
if not os.path.exists(args.dataset_dir):
raise ValueError("dataset_dir :{} does not exist".format(args.dataset_dir))
if not os.path.exists(args.xml_file):
raise ValueError("xml_file :{} does not exist".format(args.xml_file))
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
ims_labels_dict = xml_to_dict(args.xml_file, True)
num_images = len(ims_labels_dict)
annotation_list = []
print("Converting annotation, {} images in total ".format(num_images))
for i in range(num_images):
img_label = ims_labels_dict[i]
image_name = img_label['imageName']
rects = img_label['rect']
ext = image_name.split('.')[-1]
name = image_name[:-(len(ext)+1)]
fullpath = os.path.join(args.dataset_dir, image_name)
im_array = np.asarray(Image.open(fullpath))
print("processing image: {}".format(image_name))
for j, rect in enumerate(rects):
location = rect['location']
h = int(float(location['height']))
w = int(float(location['width']))
x = int(float(location['x']))
y = int(float(location['y']))
label = rect['label']
loc = [h, w, x, y]
output_name = name.replace("/", "_") + "_" + str(j) + "_" + label + '.' + ext
output_name = output_name.replace(",", "")
output_file = os.path.join(args.output_dir, output_name)
image_crop_save(im_array, loc, output_file)
ann = output_name + "," + label + ','
annotation_list.append(ann)
ann_file = args.output_annotation
with open(ann_file, 'w') as f:
for line in annotation_list:
txt = line + '\n'
f.write(txt)
if __name__ == "__main__":
convert()

View File

@ -0,0 +1,67 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import argparse
from scipy import io
###############################################
# load testdata
# testdata.mat structure
# test[:][0] : image name
# test[:][1] : label
# test[:][2] : 50 lexicon
# test[:][3] : 1000 lexicon
##############################################
def init_args():
parser = argparse.ArgumentParser('')
parser.add_argument('-m', '--mat_file', type=str, default='testdata.mat',
help='path to the testdata.mat annotation file')
parser.add_argument('-o', '--output_dir', type=str, default='./processed',
help='directory to save processed data')
parser.add_argument('-a', '--output_annotation', type=str, default='./annotation.txt',
help='path of the generated annotation file')
return parser.parse_args()
def mat_to_list(mat_file):
ann_ori = io.loadmat(mat_file)
testdata = ann_ori['testdata'][0]
ann_output = []
for elem in testdata:
img_name = elem[0]
label = elem[1]
ann = img_name + ',' + label
ann_output.append(ann)
return ann_output
def convert():
args = init_args()
ann_list = mat_to_list(args.mat_file)
ann_file = args.output_annotation
with open(ann_file, 'w') as f:
for line in ann_list:
txt = line + '\n'
f.write(txt)
if __name__ == "__main__":
convert()

View File

@ -0,0 +1,140 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
import os
import argparse
from xml.etree import ElementTree as ET
from PIL import Image
import numpy as np
def init_args():
parser = argparse.ArgumentParser('')
parser.add_argument('-d', '--dataset_dir', type=str, default='./',
help='path to original images')
parser.add_argument('-x', '--xml_file', type=str, default='test.xml',
help='path to the xml annotation file, e.g. train.xml or test.xml')
parser.add_argument('-o', '--output_dir', type=str, default='./processed',
help='directory to save the cropped images')
return parser.parse_args()
def xml_to_dict(xml_file, save_file=False):
tree = ET.parse(xml_file)
root = tree.getroot()
imgs_labels = []
for ch in root:
im_label = {}
for ch01 in ch:
if ch01.tag in "address":
continue
elif ch01.tag in 'taggedRectangles':
# multiple children
rect_list = []
for ch02 in ch01:
rect = {}
rect['location'] = ch02.attrib
rect['label'] = ch02[0].text
rect_list.append(rect)
im_label['rect'] = rect_list
else:
im_label[ch01.tag] = ch01.text
imgs_labels.append(im_label)
if save_file:
np.save("annotation_train.npy", imgs_labels)
return imgs_labels
def image_crop_save(image, location, output_dir):
"""
crop image with location (h,w,x,y)
save cropped image to output directory
"""
start_x = location[2]
end_x = start_x + location[1]
start_y = location[3]
if start_y < 0:
start_y = 0
end_y = start_y + location[0]
print("image array shape :{}".format(image.shape))
print("crop region ", start_x, end_x, start_y, end_y)
if len(image.shape) == 3:
cropped = image[start_y:end_y, start_x:end_x, :]
else:
cropped = image[start_y:end_y, start_x:end_x]
im = Image.fromarray(np.uint8(cropped))
im.save(output_dir)
def convert():
args = init_args()
if not os.path.exists(args.dataset_dir):
raise ValueError("dataset_dir :{} does not exist".format(args.dataset_dir))
if not os.path.exists(args.xml_file):
raise ValueError("xml_file :{} does not exist".format(args.xml_file))
if not os.path.exists(args.output_dir):
os.makedirs(args.output_dir)
ims_labels_dict = xml_to_dict(args.xml_file, True)
num_images = len(ims_labels_dict)
lexicon_list = []
annotation_list = []
print("Converting annotation, {} images in total ".format(num_images))
for i in range(num_images):
img_label = ims_labels_dict[i]
image_name = img_label['imageName']
lex = img_label['lex']
rects = img_label['rect']
name, ext = image_name.split('.')
fullpath = os.path.join(args.dataset_dir, image_name)
im_array = np.asarray(Image.open(fullpath))
lexicon_list.append(lex)
print("processing image: {}".format(image_name))
for j, rect in enumerate(rects):
location = rect['location']
h = int(location['height'])
w = int(location['width'])
x = int(location['x'])
y = int(location['y'])
label = rect['label']
loc = [h, w, x, y]
output_name = name + "_" + str(j) + "_" + label + '.' + ext
output_file = os.path.join(args.output_dir, output_name)
image_crop_save(im_array, loc, output_file)
ann = output_name + "," + label + ',' + str(i)
annotation_list.append(ann)
lex_file = './lexicon_ann_train.txt'
ann_file = './annotation_train.txt'
with open(lex_file, 'w') as f:
for line in lexicon_list:
txt = line + '\n'
f.write(txt)
with open(ann_file, 'w') as f:
for line in annotation_list:
txt = line + '\n'
f.write(txt)
if __name__ == "__main__":
convert()

View File

@ -0,0 +1,72 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Warpctc evaluation"""
import os
import argparse
from mindspore import context
from mindspore.common import set_seed
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.loss import CTCLoss
from src.dataset import create_dataset
from src.crnn import CRNN
from src.metric import CRNNAccuracy
set_seed(1)
parser = argparse.ArgumentParser(description="CRNN eval")
parser.add_argument("--dataset_path", type=str, default=None, help="Dataset, default is None.")
parser.add_argument("--checkpoint_path", type=str, default=None, help="checkpoint file path, default is None")
parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend', 'GPU'],
help='Running platform, choose from Ascend, GPU, and default is Ascend.')
parser.add_argument('--model', type=str, default='lowcase', help="Model type, default is lowcase")
parser.add_argument('--dataset', type=str, default='synth', choices=['synth', 'ic03', 'ic13', 'svt', 'iiit5k'])
args_opt = parser.parse_args()
if args_opt.model == 'lowcase':
from src.config import config1 as config
else:
from src.config import config2 as config
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, save_graphs=False)
if args_opt.platform == 'Ascend':
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(device_id=device_id)
if __name__ == '__main__':
config.batch_size = 1
max_text_length = config.max_text_length
input_size = config.input_size
# create dataset
dataset = create_dataset(name=args_opt.dataset,
dataset_path=args_opt.dataset_path,
batch_size=config.batch_size,
is_training=False,
config=config)
step_size = dataset.get_dataset_size()
loss = CTCLoss(max_sequence_length=config.num_step,
max_label_length=max_text_length,
batch_size=config.batch_size)
net = CRNN(config)
# load checkpoint
param_dict = load_checkpoint(args_opt.checkpoint_path)
load_param_into_net(net, param_dict)
net.set_train(False)
# define model
model = Model(net, loss_fn=loss, metrics={'CRNNAccuracy': CRNNAccuracy(config)})
# start evaluation
res = model.eval(dataset, dataset_sink_mode=args_opt.platform == 'Ascend')
print("result:", res, flush=True)

View File

@ -0,0 +1 @@
python-Levenshtein

View File

@ -0,0 +1,62 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ]; then
echo "Usage: sh run_distribute_train.sh [DATASET_NAME] [RANK_TABLE_FILE] [DATASET_PATH]"
exit 1
fi
get_real_path() {
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
DATASET_NAME=$1
PATH1=$(get_real_path $2)
PATH2=$(get_real_path $3)
if [ ! -f $PATH1 ]; then
echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
exit 1
fi
if [ ! -d $PATH2 ]; then
echo "error: DATASET_PATH=$PATH2 is not a directory"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE=$PATH1
for ((i = 0; i < ${DEVICE_NUM}; i++)); do
export DEVICE_ID=$i
export RANK_ID=$i
rm -rf ./train_parallel$i
mkdir ./train_parallel$i
cp ../*.py ./train_parallel$i
cp *.sh ./train_parallel$i
cp -r ../src ./train_parallel$i
cd ./train_parallel$i || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID"
env >env.log
python train.py --platform=Ascend --dataset_path=$PATH2 --run_distribute --dataset=$DATASET_NAME > log.txt 2>&1 &
cd ..
done

View File

@ -0,0 +1,89 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 4 ]; then
echo "Usage: sh run_eval.sh [DATASET_NAME] [DATASET_PATH] [CHECKPOINT_PATH] [PLATFORM]"
exit 1
fi
get_real_path() {
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
DATASET_NAME=$1
PATH1=$(get_real_path $2)
PATH2=$(get_real_path $3)
PLATFORM=$4
if [ ! -d $PATH1 ]; then
echo "error: DATASET_PATH=$PATH1 is not a directory"
exit 1
fi
if [ ! -f $PATH2 ]; then
echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
exit 1
fi
run_ascend() {
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0
if [ -d "eval" ]; then
rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp -r ../src ./eval
cd ./eval || exit
env >env.log
echo "start evaluation for device $DEVICE_ID"
python eval.py --dataset=$DATASET_NAME --dataset_path=$1 --checkpoint_path=$2 --platform=Ascend > log.txt 2>&1 &
cd ..
}
run_gpu() {
if [ -d "eval" ]; then
rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp -r ../src ./eval
cd ./eval || exit
env >env.log
python eval.py --dataset=$DATASET_NAME \
--dataset_path=$1 \
--checkpoint_path=$2 \
--platform=GPU > log.txt 2>&1 &
cd ..
}
if [ "Ascend" == $PLATFORM ]; then
run_ascend $PATH1 $PATH2
elif [ "GPU" == $PLATFORM ]; then
run_gpu $PATH1 $PATH2
else
echo "error: PLATFORM=$PLATFORM is not support, only support Ascend and GPU."
fi

View File

@ -0,0 +1,73 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 3 ]; then
echo "Usage: sh run_standalone_train.sh [DATASET_NAME] [DATASET_PATH] [PLATFORM]"
exit 1
fi
get_real_path() {
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
DATASET_NAME=$1
PATH1=$(get_real_path $2)
PLATFORM=$3
if [ ! -d $PATH1 ]; then
echo "error: DATASET_PATH=$PATH1 is not a directory"
exit 1
fi
export DEVICE_ID=0
run_ascend() {
ulimit -u unlimited
export DEVICE_NUM=1
export RANK_ID=0
export RANK_SIZE=1
echo "start training for device $DEVICE_ID"
env >env.log
python train.py --dataset=$DATASET_NAME --dataset_path=$1 --platform=Ascend > log.txt 2>&1 &
cd ..
}
run_gpu() {
env >env.log
python train.py --dataset=$DATASET_NAME --dataset_path=$1 --platform=GPU > log.txt 2>&1 &
cd ..
}
if [ -d "train" ]; then
rm -rf ./train
fi
WORKDIR=./train$(DEVICE_ID)
mkdir $WORKDIR
cp ../*.py $WORKDIR
cp -r ../src $WORKDIR
cd $WORKDIR || exit
if [ "Ascend" == $PLATFORM ]; then
run_ascend $PATH1
elif [ "GPU" == $PLATFORM ]; then
run_gpu $PATH1
else
echo "error: PLATFORM=$PLATFORM is not support, only support Ascend and GPU."
fi

View File

@ -0,0 +1,42 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Network parameters."""
from easydict import EasyDict
label_dict = "abcdefghijklmnopqrstuvwxyz0123456789"
# use for low case number
config1 = EasyDict({
"max_text_length": 23,
"image_width": 100,
"image_height": 32,
"batch_size": 64,
"epoch_size": 10,
"hidden_size": 256,
"learning_rate": 0.02,
"momentum": 0.95,
"nesterov": True,
"save_checkpoint": True,
"save_checkpoint_steps": 1000,
"keep_checkpoint_max": 30,
"save_checkpoint_path": "./",
"class_num": 37,
"input_size": 512,
"num_step": 24,
"use_dropout": True,
"blank": 36
})

View File

@ -0,0 +1,171 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Warpctc network definition."""
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common.initializer import TruncatedNormal
def _bn(channel):
return nn.BatchNorm2d(channel, eps=1e-4, momentum=0.9, gamma_init=1, beta_init=0, moving_mean_init=0,
moving_var_init=1)
class Conv(nn.Cell):
def __init__(self, in_channel, out_channel, kernel_size=3, stride=1, use_bn=False, pad_mode='same'):
super(Conv, self).__init__()
self.conv = nn.Conv2d(in_channel, out_channel, kernel_size=kernel_size, stride=stride,
padding=0, pad_mode=pad_mode, weight_init=TruncatedNormal(0.02))
self.bn = _bn(out_channel)
self.Relu = nn.ReLU()
self.use_bn = use_bn
def construct(self, x):
out = self.conv(x)
if self.use_bn:
out = self.bn(out)
out = self.Relu(out)
return out
class VGG(nn.Cell):
"""VGG Network structure"""
def __init__(self, is_training=True):
super(VGG, self).__init__()
self.conv1 = Conv(3, 64, use_bn=True)
self.conv2 = Conv(64, 128, use_bn=True)
self.conv3 = Conv(128, 256, use_bn=True)
self.conv4 = Conv(256, 256, use_bn=True)
self.conv5 = Conv(256, 512, use_bn=True)
self.conv6 = Conv(512, 512, use_bn=True)
self.conv7 = Conv(512, 512, kernel_size=2, pad_mode='valid', use_bn=True)
self.maxpool2d1 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='same')
self.maxpool2d2 = nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1), pad_mode='same')
self.bn1 = _bn(512)
def construct(self, x):
x = self.conv1(x)
x = self.maxpool2d1(x)
x = self.conv2(x)
x = self.maxpool2d1(x)
x = self.conv3(x)
x = self.conv4(x)
x = self.maxpool2d2(x)
x = self.conv5(x)
x = self.conv6(x)
x = self.maxpool2d2(x)
x = self.conv7(x)
return x
class CRNN(nn.Cell):
"""
Define a CRNN network which contains Bidirectional LSTM layers and vgg layer.
Args:
input_size(int): size of the feature vector fed to the LSTM at each time step, equal to the
channel dimension of the CNN feature map.
batch_size(int): batch size of input data, default is 64
hidden_size(int): the hidden size in LSTM layers, default is 256
"""
def __init__(self, config):
super(CRNN, self).__init__()
self.batch_size = config.batch_size
self.input_size = config.input_size
self.hidden_size = config.hidden_size
self.num_classes = config.class_num
self.reshape = P.Reshape()
self.cast = P.Cast()
k = (1 / self.hidden_size) ** 0.5
self.rnn1 = P.DynamicRNN(forget_bias=0.0)
self.rnn1_bw = P.DynamicRNN(forget_bias=0.0)
self.rnn2 = P.DynamicRNN(forget_bias=0.0)
self.rnn2_bw = P.DynamicRNN(forget_bias=0.0)
w1 = np.random.uniform(-k, k, (self.input_size + self.hidden_size, 4 * self.hidden_size))
self.w1 = Parameter(w1.astype(np.float16), name="w1")
w2 = np.random.uniform(-k, k, (2 * self.hidden_size + self.hidden_size, 4 * self.hidden_size))
self.w2 = Parameter(w2.astype(np.float16), name="w2")
w1_bw = np.random.uniform(-k, k, (self.input_size + self.hidden_size, 4 * self.hidden_size))
self.w1_bw = Parameter(w1_bw.astype(np.float16), name="w1_bw")
w2_bw = np.random.uniform(-k, k, (2 * self.hidden_size + self.hidden_size, 4 * self.hidden_size))
self.w2_bw = Parameter(w2_bw.astype(np.float16), name="w2_bw")
self.b1 = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float16), name="b1")
self.b2 = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float16), name="b2")
self.b1_bw = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float16), name="b1_bw")
self.b2_bw = Parameter(np.random.uniform(-k, k, (4 * self.hidden_size)).astype(np.float16), name="b2_bw")
self.h1 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16))
self.h2 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16))
self.h1_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16))
self.h2_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16))
self.c1 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16))
self.c2 = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16))
self.c1_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16))
self.c2_bw = Tensor(np.zeros(shape=(1, self.batch_size, self.hidden_size)).astype(np.float16))
self.fc_weight = np.random.random((self.num_classes, self.hidden_size)).astype(np.float32)
self.fc_bias = np.random.random((self.num_classes)).astype(np.float32)
self.fc = nn.Dense(in_channels=self.hidden_size, out_channels=self.num_classes,
weight_init=Tensor(self.fc_weight), bias_init=Tensor(self.fc_bias))
self.fc.to_float(mstype.float32)
self.expand_dims = P.ExpandDims()
self.concat = P.Concat()
self.transpose = P.Transpose()
self.squeeze = P.Squeeze(axis=0)
self.vgg = VGG()
self.reverse_seq1 = P.ReverseSequence(batch_dim=1, seq_dim=0)
self.reverse_seq2 = P.ReverseSequence(batch_dim=1, seq_dim=0)
self.reverse_seq3 = P.ReverseSequence(batch_dim=1, seq_dim=0)
self.reverse_seq4 = P.ReverseSequence(batch_dim=1, seq_dim=0)
self.seq_length = Tensor(np.ones((self.batch_size), np.int32) * config.num_step, mstype.int32)
self.concat1 = P.Concat(axis=2)
self.dropout = nn.Dropout(0.5)
self.rnn_dropout = nn.Dropout(0.9)
self.use_dropout = config.use_dropout
def construct(self, x):
x = self.vgg(x)
x = self.cast(x, mstype.float16)
x = self.reshape(x, (self.batch_size, self.input_size, -1))
x = self.transpose(x, (2, 0, 1))
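# first bidirectional LSTM layer: reverse the time axis for the backward DynamicRNN,
# then restore the original order before concatenating with the forward output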
bw_x = self.reverse_seq1(x, self.seq_length)
y1, _, _, _, _, _, _, _ = self.rnn1(x, self.w1, self.b1, None, self.h1, self.c1)
y1_bw, _, _, _, _, _, _, _ = self.rnn1_bw(bw_x, self.w1_bw, self.b1_bw, None, self.h1_bw, self.c1_bw)
y1_bw = self.reverse_seq2(y1_bw, self.seq_length)
y1_out = self.concat1((y1, y1_bw))
if self.use_dropout:
y1_out = self.rnn_dropout(y1_out)
y2, _, _, _, _, _, _, _ = self.rnn2(y1_out, self.w2, self.b2, None, self.h2, self.c2)
bw_y = self.reverse_seq3(y1_out, self.seq_length)
y2_bw, _, _, _, _, _, _, _ = self.rnn2_bw(bw_y, self.w2_bw, self.b2_bw, None, self.h2_bw, self.c2_bw)
y2_bw = self.reverse_seq4(y2_bw, self.seq_length)
y2_out = self.concat1((y2, y2_bw))
if self.use_dropout:
y2_out = self.dropout(y2_out)
output = ()
for i in range(F.shape(y2_out)[0]):
y2_after_fc = self.fc(self.squeeze(y2[i:i+1:1]))
y2_after_fc = self.expand_dims(y2_after_fc, 0)
output += (y2_after_fc,)
output = self.concat(output)
return output

View File

@ -0,0 +1,114 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Automatic differentiation with grad clip."""
import numpy as np
from mindspore.parallel._utils import (_get_device_num, _get_gradients_mean,
_get_parallel_mode)
from mindspore.context import ParallelMode
from mindspore.common import dtype as mstype
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.nn.cell import Cell
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
import mindspore.nn as nn
from mindspore.common.tensor import Tensor
compute_norm = C.MultitypeFuncGraph("compute_norm")
@compute_norm.register("Tensor")
def _compute_norm(grad):
norm = nn.Norm()
norm = norm(F.cast(grad, mstype.float32))
ret = F.expand_dims(F.cast(norm, mstype.float32), 0)
return ret
grad_div = C.MultitypeFuncGraph("grad_div")
@grad_div.register("Tensor", "Tensor")
def _grad_div(val, grad):
div = P.RealDiv()
mul = P.Mul()
scale = div(10.0, val)
ret = mul(grad, scale)
return ret
class TrainOneStepCellWithGradClip(Cell):
"""
Network training package class.
Wraps the network with an optimizer. The resulting Cell is trained with input data and label.
Backward graph with grad clip will be created in the construct function to do parameter updating.
Different parallel modes are available to run the training.
Args:
network (Cell): The training network.
optimizer (Cell): Optimizer for updating the weights.
sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.
Inputs:
- data (Tensor) - Tensor of shape :(N, ...).
- label (Tensor) - Tensor of shape :(N, ...).
Outputs:
Tensor, a scalar Tensor with shape :math:`()`.
"""
def __init__(self, network, optimizer, sens=1.0):
super(TrainOneStepCellWithGradClip, self).__init__(auto_prefix=False)
self.network = network
self.network.set_grad()
self.network.add_flags(defer_inline=True)
self.weights = optimizer.parameters
self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.sens = sens
self.reducer_flag = False
self.grad_reducer = None
self.hyper_map = C.HyperMap()
self.greater = P.Greater()
self.select = P.Select()
self.norm = nn.Norm(keep_dims=True)
self.dtype = P.DType()
self.cast = P.Cast()
self.concat = P.Concat(axis=0)
self.ten = Tensor(np.array([10.0]).astype(np.float32))
parallel_mode = _get_parallel_mode()
if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
self.reducer_flag = True
if self.reducer_flag:
mean = _get_gradients_mean()
degree = _get_device_num()
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
def construct(self, data, label):
weights = self.weights
loss = self.network(data, label)
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
grads = self.grad(self.network, weights)(data, label, sens)
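# clip gradients by global norm: compute the norm over all gradients and,
# if it exceeds 10.0, scale every gradient by 10.0 / norm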
norm = self.hyper_map(F.partial(compute_norm), grads)
norm = self.concat(norm)
norm = self.norm(norm)
cond = self.greater(norm, self.cast(self.ten, self.dtype(norm)))
clip_val = self.select(cond, norm, self.cast(self.ten, self.dtype(norm)))
grads = self.hyper_map(F.partial(grad_div, clip_val), grads)
if self.reducer_flag:
# apply grad reducer on grads
grads = self.grad_reducer(grads)
return F.depend(loss, self.optimizer(grads))

View File

@ -0,0 +1,121 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dataset preprocessing."""
import os
import numpy as np
import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de
import mindspore.dataset.transforms.c_transforms as C
import mindspore.dataset.vision.c_transforms as vc
from PIL import Image, ImageFile
from src.config import config1, label_dict
from src.ic03_dataset import IC03Dataset
from src.ic13_dataset import IC13Dataset
from src.iiit5k_dataset import IIIT5KDataset
from src.svt_dataset import SVTDataset
ImageFile.LOAD_TRUNCATED_IMAGES = True
class CaptchaDataset:
"""
create train or evaluation dataset for crnn
Args:
img_root_dir(str): root path of images
is_training(bool): whether the dataset is used for training, default is True
config: network configuration, default is config1
"""
def __init__(self, img_root_dir, is_training=True, config=config1):
if not os.path.exists(img_root_dir):
raise RuntimeError("the input image dir {} is invalid!".format(img_root_dir))
self.img_root_dir = img_root_dir
if is_training:
self.imgslist = os.path.join(self.img_root_dir, 'annotation_train.txt')
else:
self.imgslist = os.path.join(self.img_root_dir, 'annotation_test.txt')
self.lexicon_file = os.path.join(self.img_root_dir, 'lexicon.txt')
with open(self.lexicon_file, 'r') as f:
self.lexicons = [line.strip('\n') for line in f]
f.close()
self.img_names = {}
self.img_list = []
with open(self.imgslist, 'r') as f:
for line in f:
img_name, label_index = line.strip('\n').split(" ")
self.img_list.append(img_name)
self.img_names[img_name] = self.lexicons[int(label_index)]
f.close()
self.max_text_length = config.max_text_length
self.blank = config.blank
self.class_num = config.class_num
def __len__(self):
return len(self.img_names)
def __getitem__(self, item):
img_name = self.img_list[item]
im = Image.open(os.path.join(self.img_root_dir, img_name))
im = im.convert("RGB")
r, g, b = im.split()
im = Image.merge("RGB", (b, g, r))
image = np.array(im)
label_str = self.img_names[img_name]
label = []
for c in label_str:
if c in label_dict:
label.append(label_dict.index(c))
label.extend([int(self.blank)] * (self.max_text_length - len(label)))
label = np.array(label)
return image, label
def create_dataset(name, dataset_path, batch_size=1, num_shards=1, shard_id=0, is_training=True, config=config1):
"""
create train or evaluation dataset for crnn
Args:
name(str): dataset name, one of synth, ic03, ic13, svt, iiit5k
dataset_path(str): dataset path
batch_size(int): batch size of generated dataset, default is 1
num_shards(int): number of devices
shard_id(int): rank id
is_training(bool): whether the dataset is used for training, default is True
config: network configuration, default is config1
"""
if name == 'synth':
dataset = CaptchaDataset(dataset_path, is_training, config)
elif name == 'ic03':
dataset = IC03Dataset(dataset_path, "annotation.txt", config, True, 3)
elif name == 'ic13':
dataset = IC13Dataset(dataset_path, "Challenge2_Test_Task3_GT.txt", config)
elif name == 'svt':
dataset = SVTDataset(dataset_path, config)
elif name == 'iiit5k':
dataset = IIIT5KDataset(dataset_path, "annotation.txt", config)
else:
raise ValueError(f"unsupported dataset name: {name}")
ds = de.GeneratorDataset(dataset, ["image", "label"], shuffle=True, num_shards=num_shards, shard_id=shard_id)
image_trans = [
vc.Resize((config.image_height, config.image_width)),
vc.Normalize(mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5]),
vc.HWC2CHW()
]
label_trans = [
C.TypeCast(mstype.int32)
]
ds = ds.map(operations=image_trans, input_columns=["image"], num_parallel_workers=8)
ds = ds.map(operations=label_trans, input_columns=["label"], num_parallel_workers=8)
ds = ds.batch(batch_size, drop_remainder=True)
return ds

View File

@ -0,0 +1,80 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dataset adaptor for SVT"""
import os
import numpy as np
from PIL import Image, ImageFile
from src.config import config1, label_dict
ImageFile.LOAD_TRUNCATED_IMAGES = True
class IC03Dataset:
"""
create train or evaluation dataset for crnn
Args:
img_root_dir(str): root path of images
anno_file(str): name of the annotation file under img_root_dir, default is annotation.txt
config: network configuration, default is config1
filter_by_dict(bool): drop samples containing characters outside label_dict, default is True
filter_length(int): drop samples whose label is shorter than this value, default is 3
"""
def __init__(self, img_root_dir, anno_file="annotation.txt", config=config1, filter_by_dict=True, filter_length=3):
if not os.path.exists(img_root_dir):
raise RuntimeError("the input image dir {} is invalid!".format(img_root_dir))
self.img_root_dir = img_root_dir
anno_file = os.path.join(img_root_dir, anno_file)
self.img_names = {}
self.img_list = []
with open(anno_file, 'r') as f:
for lines in f:
img_name = lines.split(",")[0]
label = lines.split(",")[1].lower()
if len(label) < filter_length:
continue
if filter_by_dict:
flag = True
for c in label:
if c not in label_dict:
flag = False
break
if not flag:
continue
self.img_names[img_name] = label
self.img_list.append(img_name)
self.max_text_length = config.max_text_length
self.blank = config.blank
self.class_num = config.class_num
def __len__(self):
return len(self.img_names)
def __getitem__(self, item):
img_name = self.img_list[item]
im = Image.open(os.path.join(self.img_root_dir, img_name))
im = im.convert("RGB")
r, g, b = im.split()
im = Image.merge("RGB", (b, g, r))
image = np.array(im)
label_str = self.img_names[img_name]
label = []
for c in label_str:
if c in label_dict:
label.append(label_dict.index(c))
label.extend([int(self.blank)] * (self.max_text_length - len(label)))
label = np.array(label)
return image, label

View File

@ -0,0 +1,77 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dataset adaptor for SVT"""
import os
import numpy as np
from PIL import Image, ImageFile
from src.config import config1, label_dict
ImageFile.LOAD_TRUNCATED_IMAGES = True
class IC13Dataset:
"""
create evaluation dataset for crnn
Args:
img_root_dir(str): root path of images
label_file(str): name of the ground truth file under img_root_dir
config: network configuration, default is config1
filter_by_dict(bool): drop samples containing characters outside label_dict, default is True
"""
def __init__(self, img_root_dir, label_file="", config=config1, filter_by_dict=True, filter_length=3):
if not os.path.exists(img_root_dir):
raise RuntimeError("the input image dir {} is invalid".format(img_root_dir))
self.img_root_dir = img_root_dir
self.label_file = os.path.join(img_root_dir, label_file)
self.img_names = {}
self.img_list = []
self.config = config
with open(self.label_file, 'r') as f:
for lines in f:
img_name = lines.split(",")[0]
label = lines.split("\"")[1].lower()
if len(label) < filter_length:
continue
if filter_by_dict:
flag = True
for c in label:
if c not in label_dict:
flag = False
break
if not flag:
continue
self.img_names[img_name] = label
self.img_list.append(img_name)
f.close()
self.max_text_length = config.max_text_length
self.blank = config.blank
self.class_num = config.class_num
def __len__(self):
return len(self.img_names)
def __getitem__(self, item):
img_name = self.img_list[item]
im = Image.open(os.path.join(self.img_root_dir, img_name))
im = im.convert("RGB")
r, g, b = im.split()
im = Image.merge("RGB", (b, g, r))
image = np.array(im)
label_str = self.img_names[img_name]
label = []
for c in label_str:
if c in label_dict:
label.append(label_dict.index(c))
label.extend([int(self.blank)] * (self.max_text_length - len(label)))
label = np.array(label)
return image, label

View File

@ -0,0 +1,69 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dataset adaptor for SVT"""
import os
import numpy as np
from PIL import Image, ImageFile
from src.config import config1, label_dict
ImageFile.LOAD_TRUNCATED_IMAGES = True
class IIIT5KDataset:
"""
create train or evaluation dataset for crnn
Args:
img_root_dir(str): root path of images
anno_file(str): name of the annotation file under img_root_dir, default is annotation.txt
config: network configuration, default is config1
"""
def __init__(self, img_root_dir, anno_file="annotation.txt", config=config1):
if not os.path.exists(img_root_dir):
raise RuntimeError("the input image dir {} is invalid!".format(img_root_dir))
self.img_root_dir = img_root_dir
anno_file = os.path.join(img_root_dir, anno_file)
self.img_names = {}
self.img_list = []
with open(anno_file, 'r') as f:
for lines in f:
img_name = lines.split(",")[0]
label = lines.split(",")[1].lower()
self.img_names[img_name] = label
self.img_list.append(img_name)
self.max_text_length = config.max_text_length
self.blank = config.blank
self.class_num = config.class_num
def __len__(self):
return len(self.img_names)
def __getitem__(self, item):
img_name = self.img_list[item]
im = Image.open(os.path.join(self.img_root_dir, img_name))
im = im.convert("RGB")
r, g, b = im.split()
im = Image.merge("RGB", (b, g, r))
image = np.array(im)
label_str = self.img_names[img_name]
label = []
for c in label_str:
if c in label_dict:
label.append(label_dict.index(c))
label.extend([int(self.blank)] * (self.max_text_length - len(label)))
label = np.array(label)
return image, label

View File

@ -0,0 +1,49 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CTC Loss."""
import numpy as np
from mindspore.nn.loss.loss import _Loss
from mindspore import Tensor, Parameter
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
class CTCLoss(_Loss):
"""
CTCLoss definition
Args:
max_sequence_length(int): max length of the input sequence. For text images, the value equals the
number of time steps (num_step) of the feature sequence.
max_label_length(int): max number of label length for each input.
batch_size(int): batch size of input logits
"""
def __init__(self, max_sequence_length, max_label_length, batch_size):
super(CTCLoss, self).__init__()
self.sequence_length = Parameter(Tensor(np.array([max_sequence_length] * batch_size), mstype.int32),
name="sequence_length")
labels_indices = []
for i in range(batch_size):
for j in range(max_label_length):
labels_indices.append([i, j])
self.labels_indices = Parameter(Tensor(np.array(labels_indices), mstype.int64), name="labels_indices")
self.reshape = P.Reshape()
self.ctc_loss = P.CTCLoss(ctc_merge_repeated=True)
def construct(self, logit, label):
labels_values = self.reshape(label, (-1,))
loss, _ = self.ctc_loss(logit, self.labels_indices, labels_values, self.sequence_length)
return loss

View File

@ -0,0 +1,94 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Metric for accuracy evaluation."""
from mindspore import nn
import Levenshtein
label_dict = "abcdefghijklmnopqrstuvwxyz0123456789"
class CRNNAccuracy(nn.Metric):
"""
Define accuracy metric for CRNN network.
"""
def __init__(self, config):
super(CRNNAccuracy, self).__init__()
self.config = config
self._correct_num = 0
self._total_num = 0
self.blank = config.blank
def clear(self):
self._correct_num = 0
self._total_num = 0
def update(self, *inputs):
if len(inputs) != 2:
raise ValueError('CRNNAccuracy needs 2 inputs (y_pred, y), but got {}'.format(len(inputs)))
y_pred = self._convert_data(inputs[0])
y = self._convert_data(inputs[1])
str_pred = self._ctc_greedy_decoder(y_pred)
str_label = self._convert_labels(y)
for pred, label in zip(str_pred, str_label):
print(pred, " :: ", label)
edit_distance = Levenshtein.distance(pred, label)
self._total_num += 1
if edit_distance == 0:
self._correct_num += 1
def eval(self):
if self._total_num == 0:
raise RuntimeError('Accuracy can not be calculated, because the number of samples is 0.')
print('correct num: ', self._correct_num, ', total num: ', self._total_num)
sequence_accuracy = self._correct_num / self._total_num
return sequence_accuracy
def _arr2char(self, inputs):
string = ""
for i in inputs:
if i < self.blank:
string += label_dict[i]
return string
def _convert_labels(self, inputs):
str_list = []
for label in inputs:
str_temp = self._arr2char(label)
str_list.append(str_temp)
return str_list
def _ctc_greedy_decoder(self, y_pred):
"""
parse the prediction result into label strings
"""
seq_len, batch_size, _ = y_pred.shape
indices = y_pred.argmax(axis=2)
lens = [seq_len] * batch_size
pred_labels = []
for i in range(batch_size):
idx = indices[:, i]
last_idx = self.blank
pred_label = []
for j in range(lens[i]):
cur_idx = idx[j]
if cur_idx not in [last_idx, self.blank]:
pred_label.append(cur_idx)
last_idx = cur_idx
pred_labels.append(pred_label)
str_results = []
for i in pred_labels:
str_results.append(self._arr2char(i))
return str_results
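# Decoding illustration (assumed values): with label_dict above and blank index 36, a
# per-step argmax sequence such as [2, 2, 36, 0, 36, 19, 19] collapses to the indices
# [2, 0, 19], i.e. the string "cat": repeated indices are merged and blanks are dropped.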

View File

@ -0,0 +1,67 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dataset adaptor for SVT"""
import os
import numpy as np
from PIL import Image, ImageFile
from src.config import config1, label_dict
ImageFile.LOAD_TRUNCATED_IMAGES = True
class SVTDataset:
"""
create a train or evaluation dataset for crnn from the SVT data
Args:
img_root_dir(str): root path of the cropped word images.
config: model configuration providing max_text_length, blank and class_num, default is config1.
"""
def __init__(self, img_root_dir, config=config1):
if not os.path.exists(img_root_dir):
raise RuntimeError("the input image dir {} is invalid!".format(img_root_dir))
self.img_root_dir = img_root_dir
file_list = os.listdir(img_root_dir)
self.img_names = {}
self.img_list = []
for f in file_list:
label = f.split(".jpg")[0]
label = label.split("_")[-1].lower()
self.img_names[f] = label
self.img_list.append(f)
self.max_text_length = config.max_text_length
self.blank = config.blank
self.class_num = config.class_num
def __len__(self):
return len(self.img_names)
def __getitem__(self, item):
img_name = self.img_list[item]
im = Image.open(os.path.join(self.img_root_dir, img_name))
im = im.convert("RGB")
r, g, b = im.split()
im = Image.merge("RGB", (b, g, r))
image = np.array(im)
label_str = self.img_names[img_name]
label = []
for c in label_str:
if c in label_dict:
label.append(label_dict.index(c))
label.extend([int(self.blank)] * (self.max_text_length - len(label)))
label = np.array(label)
return image, label
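# A minimal usage sketch; the directory below is a hypothetical example path, not one
# shipped with this repo.
if __name__ == "__main__":
    svt = SVTDataset("/path/to/svt/cropped_images", config1)
    image, label = svt[0]
    print(len(svt), image.shape, label)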

View File

@ -0,0 +1,101 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""crnn training"""
import os
import argparse
import mindspore.nn as nn
from mindspore import context
from mindspore.common import set_seed
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.nn.wrap import WithLossCell
from mindspore.train.callback import TimeMonitor, LossMonitor, CheckpointConfig, ModelCheckpoint
from mindspore.communication.management import init, get_group_size, get_rank
from src.loss import CTCLoss
from src.dataset import create_dataset
from src.crnn import CRNN
from src.crnn_for_train import TrainOneStepCellWithGradClip
set_seed(1)
parser = argparse.ArgumentParser(description="crnn training")
parser.add_argument("--run_distribute", action='store_true', help="Run distribute, default is false.")
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path, default is None')
parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend', 'GPU'],
help='Running platform, choose from Ascend or GPU; default is Ascend.')
parser.add_argument('--model', type=str, default='lowercase', help="Model type, default is lowercase")
parser.add_argument('--dataset', type=str, default='synth', choices=['synth', 'ic03', 'ic13', 'svt', 'iiit5k'])
parser.set_defaults(run_distribute=False)
args_opt = parser.parse_args()
if args_opt.model == 'lowercase':
from src.config import config1 as config
else:
from src.config import config2 as config
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, save_graphs=False)
if args_opt.platform == 'Ascend':
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(device_id=device_id)
if __name__ == '__main__':
lr_scale = 1
if args_opt.run_distribute:
if args_opt.platform == 'Ascend':
init()
lr_scale = 1
device_num = int(os.environ.get("RANK_SIZE"))
rank = int(os.environ.get("RANK_ID"))
else:
init()
lr_scale = 1
device_num = get_group_size()
rank = get_rank()
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=device_num,
parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
else:
device_num = 1
rank = 0
max_text_length = config.max_text_length
# create dataset
dataset = create_dataset(name=args_opt.dataset, dataset_path=args_opt.dataset_path, batch_size=config.batch_size,
num_shards=device_num, shard_id=rank, config=config)
step_size = dataset.get_dataset_size()
# define lr
lr_init = config.learning_rate
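# cosine schedule: cosine_decay_lr(min_lr, max_lr, total_step, step_per_epoch, decay_epoch)
# decays the rate from lr_init down to 0 over the full run of epoch_size epochs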
lr = nn.dynamic_lr.cosine_decay_lr(0.0, lr_init, config.epoch_size * step_size, step_size, config.epoch_size)
loss = CTCLoss(max_sequence_length=config.num_step,
max_label_length=max_text_length,
batch_size=config.batch_size)
net = CRNN(config)
opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=config.momentum, nesterov=config.nesterov)
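# wrap the backbone with the CTC loss, then with the custom one-step cell that clips
# gradients before applying the SGD update (see src/crnn_for_train.py)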
net = WithLossCell(net, loss)
net = TrainOneStepCellWithGradClip(net, opt).set_train()
# define model
model = Model(net)
# define callbacks
callbacks = [LossMonitor(), TimeMonitor(data_size=step_size)]
if config.save_checkpoint:
config_ck = CheckpointConfig(save_checkpoint_steps=config.save_checkpoint_steps,
keep_checkpoint_max=config.keep_checkpoint_max)
save_ckpt_path = os.path.join(config.save_checkpoint_path, 'ckpt_' + str(rank) + '/')
ckpt_cb = ModelCheckpoint(prefix="crnn", directory=save_ckpt_path, config=config_ck)
callbacks.append(ckpt_cb)
model.train(config.epoch_size, dataset, callbacks=callbacks)

View File

@ -4,11 +4,11 @@
- [MobileNetV1 Description](#mobilenetv1-description)
- [Model architecture](#model-architecture)
- [Dataset](#dataset)
- [[Features]](#features)
- [[Mixed Precision(Ascend)]](#mixed-precisionascend)
- [[Environment Requirements]](#environment-requirements)
- [[Script description]](#script-description)
- [[Script and sample code]](#script-and-sample-code)
- [Features](#features)
- [Mixed Precision(Ascend)](#mixed-precisionascend)
- [Environment Requirements](#environment-requirements)
- [Script description](#script-description)
- [Script and sample code](#script-and-sample-code)
- [Training process](#training-process)
- [Usage](#usage)
- [Launch](#launch)
@ -17,7 +17,7 @@
- [Usage](#usage-1)
- [Launch](#launch-1)
- [Result](#result-1)
- [[Model description]](#model-description)
- [Model description](#model-description)
- [Performance](#performance)
- [Training Performance](#training-performance)
- [Description of Random Situation](#description-of-random-situation)
@ -37,6 +37,8 @@ The overall network architecture of MobileNetV1 is show below:
## [Dataset](#contents)
Note that you can run the scripts based on the dataset mentioned in original paper or widely used in relevant domain/network architecture. In the following sections, we will introduce how to run the scripts using the related dataset below.
Dataset used: [ImageNet2012](http://www.image-net.org/)
- Dataset size: 224*224 color images in 1000 classes