!11788 Adding CRNN-Seq2Seq-OCR model to MindSpore model zoo

From: @alashkari
Reviewed-by: 
Signed-off-by:
mindspore-ci-bot 2021-02-22 10:27:53 +08:00 committed by Gitee
commit 3952a57d85
17 changed files with 2134 additions and 0 deletions


@ -0,0 +1,196 @@
# Contents
- [Contents](#contents)
- [CRNN-Seq2Seq-OCR Description](#crnn-seq2seq-ocr-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Dataset Prepare](#dataset-prepare)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters)
- [Training Script Parameters](#training-script-parameters)
- [Parameters Configuration](#parameters-configuration)
- [Dataset Preparation](#dataset-preparation)
- [Training Process](#training-process)
- [Training](#training)
- [Distributed Training](#distributed-training)
- [Evaluation Process](#evaluation-process)
- [Evaluation](#evaluation)
- [Model Description](#model-description)
- [Performance](#performance)
- [Training Performance](#training-performance)
- [Evaluation Performance](#evaluation-performance)
## [CRNN-Seq2Seq-OCR Description](#contents)
CRNN-Seq2Seq-OCR is a neural network model for image based sequence recognition tasks, such as scene text recognition and optical character recognition (OCR). Its architecture is a combination of CNN and sequence to sequence model with attention mechanism.
## [Model Architecture](#contents)
CRNN-Seq2Seq-OCR applies a VGG-style backbone to extract features from the processed images, followed by an attention-based encoder and decoder, and finally uses a negative log-likelihood (NLL) loss. See src/attention_ocr.py for details.
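Below is a minimal sketch, not the shipped `train.py`, of how the pieces under `src/` compose into a trainable graph; the optimizer shown uses the Adam hyperparameters from `config.py` and is illustrative only (the performance table below lists SGD).

```python
import mindspore.nn as nn
from src.config import config
from src.attention_ocr import AttentionOCR, AttentionOCRWithLossCell, TrainingWrapper

# VGG-style CNN backbone + attention-based encoder/decoder;
# conv_out_dim follows the img_width / 4 convention used in eval.py
net = AttentionOCR(config.batch_size,
                   int(config.img_width / 4),
                   config.encoder_hidden_size,
                   config.decoder_hidden_size,
                   config.decoder_output_size,
                   config.max_length,
                   config.dropout_p)
# the wrapper sums the per-step NLL losses over all decoding steps
loss_net = AttentionOCRWithLossCell(net, config.max_length)
opt = nn.Adam(loss_net.trainable_params(), learning_rate=config.lr,
              beta1=config.adam_beta1, beta2=config.adam_beta2,
              loss_scale=config.loss_scale)
train_net = TrainingWrapper(loss_net, opt, sens=config.loss_scale)
```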
## [Dataset](#contents)
For training and evaluation, we use the French Street Name Signs (FSNS) dataset released by Google as the training data; it contains approximately 1 million training images and their corresponding ground-truth words.
## [Environment Requirements](#contents)
- Hardware (Ascend)
- Prepare hardware environment with Ascend processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. You will be able to have access to related resources once approved.
- Framework
- [MindSpore](https://gitee.com/mindspore/mindspore)
- For more information, please check the resources below
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
## [Quick Start](#contents)
- After the dataset is prepared, you may start running the training or the evaluation scripts as follows:
- Running on Ascend
```shell
# distribute training example in Ascend
$ bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH]
# evaluation example in Ascend
$ bash run_eval_ascend.sh [DATASET_PATH] [CHECKPOINT_PATH]
# standalone training example in Ascend
$ bash run_standalone_train.sh [DATASET_PATH]
```
For distributed training, an hccl configuration file in JSON format needs to be created in advance.
Please follow the instructions in the link below:
[hccl_tools](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).
## [Script Description](#contents)
### [Script and Sample Code](#contents)
```shell
crnn-seq2seq-ocr
├── README.md # Descriptions about CRNN-Seq2Seq-OCR
├── scripts
│   ├── run_distribute_train.sh # Launch distributed training on Ascend (8 pcs)
│   ├── run_eval_ascend.sh # Launch Ascend evaluation
│   └── run_standalone_train.sh # Launch standalone training on Ascend (1 pc)
├── src
│   ├── attention_ocr.py # CRNN-Seq2Seq-OCR training wrapper
│   ├── cnn.py # VGG network
│   ├── config.py # Parameter configuration
│   ├── create_mindrecord_files.py # Create mindrecord files from images and ground truth
│   ├── dataset.py # Data preprocessing for training and evaluation
│   ├── gru.py # GRU cell wrapper
│   ├── logger.py # Logger configuration
│   ├── lstm.py # LSTM cell wrapper
│   ├── seq2seq.py # CRNN-Seq2Seq-OCR model structure
│   ├── utils.py # Utility functions for training and data pre-processing
│   └── weight_init.py # Weight initialization of LSTM and GRU
├── eval.py # Evaluation script
└── train.py # Training script
```
### [Script Parameters](#contents)
#### Training Script Parameters
```shell
# distributed training on Ascend
Usage: bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH]
# standalone training
Usage: bash run_standalone_train.sh [DATASET_PATH]
```
#### Parameters Configuration
Parameters for both training and evaluation can be set in config.py.
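Since `config` is a plain `EasyDict`, its fields can also be inspected or overridden in code before the network is built; a small, purely illustrative example:

```python
from src.config import config

print(config.batch_size, config.lr, config.max_length)
# override a couple of fields for a quick experiment (illustrative only)
config.lr = 0.0004
config.num_epochs = 10
```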
### [Dataset Preparation](#contents)
- You may refer to "Generate dataset" in [Quick Start](#quick-start) to automatically generate a dataset, or you may choose to generate a text image dataset by yourself.
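The MindRecord conversion helper added in this PR can also be driven directly from Python once the dataset paths in `src/config.py` (`mindrecord_dir`, `data_root`, the annotation files and `vocab_path`) have been filled in; this is a sketch, not a shipped entry point:

```python
from src.create_mindrecord_files import create_mindrecord

# writes <mindrecord_dir>/train/fsns.mindrecord* and <mindrecord_dir>/val/fsns.mindrecord*
train_files = create_mindrecord(dataset="fsns", prefix="fsns.mindrecord", is_training=True)
val_files = create_mindrecord(dataset="fsns", prefix="fsns.mindrecord", is_training=False)
print(train_files, val_files)
```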
## [Training Process](#contents)
- Set options in `config.py`, including the learning rate and other network hyperparameters. Click [MindSpore dataset preparation tutorial](https://www.mindspore.cn/tutorial/training/zh-CN/master/use/data_preparation.html) for more information about dataset preparation.
### [Training](#contents)
- Run `run_standalone_train.sh` for non-distributed training of the CRNN-Seq2Seq-OCR model; only Ascend is supported for now.
``` bash
bash run_standalone_train.sh [DATASET_PATH]
```
#### [Distributed Training](#contents)
- Run `run_distribute_train.sh` for distributed training of the CRNN-Seq2Seq-OCR model on Ascend.
``` bash
bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH]
```
Check `train_parallel0/log.txt` and you will see output like the following:
```shell
epoch: 20 step: 4080, loss is 1.56112
epoch: 20 step: 4081, loss is 1.6368448
epoch time: 1559886.096 ms, per step time: 382.231 ms
```
## [Evaluation Process](#contents)
### [Evaluation](#contents)
- Run `run_eval_ascend.sh` for evaluation on Ascend.
``` bash
bash run_eval_ascend.sh [DATASET_PATH] [CHECKPOINT_PATH]
```
Check `eval/log` and you will see output like the following:
```shell
character precision = 0.967522
Annotation precision = 0.635204
```
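For reference, `eval.py` reports annotation precision as the exact-match rate over whole annotations and character precision as an LCS-based score, `2 * LCS(gt, pred) / (len(gt) + len(pred))`. A small illustration of the formulas with made-up strings:

```python
gt, pred = "Rue de la Paix", "Rue de la Paix"   # hypothetical ground truth and prediction
word_correct = int(pred == gt)                   # counts toward annotation precision
# for an exact match the longest common subsequence is the whole string,
# so the character score is 2 * len(gt) / (len(gt) + len(pred)) = 1.0
char_score = 2 * len(gt) / (len(gt) + len(pred))
print(word_correct, char_score)
```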
## [Model Description](#contents)
### [Performance](#contents)
#### Training Performance
| Parameters | Ascend |
| -------------------------- | ----------------------------------------------------------- |
| Model Version | V1 |
| Resource | Ascend 910; CPU 2.60GHz, 192 cores; memory 755G |
| Uploaded Date | 02/11/2021 (month/day/year) |
| MindSpore Version | 1.2.0 |
| Dataset | FSNS |
| Training Parameters | epoch=20, batch_size=32 |
| Optimizer | SGD |
| Loss Function | Negative Log Likelihood |
| Speed | 1pc: 355 ms/step; 8pcs: 385 ms/step |
| Total time | 1pc: 64 hours; 8pcs: 9 hours |
| Parameters (M) | 12 |
| Scripts | [crnn_seq2seq_ocr script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/crnn_seq2seq_ocr) |
#### Evaluation Performance
| Parameters | Ascend |
| ------------------- | --------------------------- |
| Model Version | V1 |
| Resource | Ascend 910 |
| Uploaded Date | 02/11/2021 (month/day/year) |
| MindSpore Version | 1.2.0 |
| Dataset | FSNS |
| batch_size | 32 |
| outputs | Annotation Precision, Character Precision |
| Accuracy | Annotation Precision=63.52%, Character Precision=96.75% |
| Model for inference | 12M (.ckpt file) |


@ -0,0 +1,181 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
CRNN-Seq2Seq-OCR Evaluation.
"""
import os
import codecs
import argparse
import numpy as np
import mindspore.ops.operations as P
import mindspore.common.dtype as mstype
from mindspore.common import set_seed
from mindspore import context, Tensor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.config import config
from src.utils import initialize_vocabulary
from src.dataset import create_ocr_val_dataset
from src.attention_ocr import AttentionOCRInfer
set_seed(1)
def text_standardization(text_in):
"""
replace some particular characters
"""
stand_text = text_in.strip()
stand_text = ' '.join(stand_text.split())
stand_text = stand_text.replace(u'(', u'')
stand_text = stand_text.replace(u')', u'')
stand_text = stand_text.replace(u':', u'')
return stand_text
def LCS_length(str1, str2):
"""
calculate longest common sub-sequence between str1 and str2
"""
if str1 is None or str2 is None:
return 0
len1 = len(str1)
len2 = len(str2)
if len1 == 0 or len2 == 0:
return 0
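# rolling two-row DP: only the previous and current rows of the LCS table are kept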
lcs = [[0 for _ in range(len2 + 1)] for _ in range(2)]
for i in range(1, len1 + 1):
for j in range(1, len2 + 1):
if str1[i - 1] == str2[j - 1]:
lcs[i % 2][j] = lcs[(i - 1) % 2][j - 1] + 1
else:
if lcs[i % 2][j - 1] >= lcs[(i - 1) % 2][j]:
lcs[i % 2][j] = lcs[i % 2][j - 1]
else:
lcs[i % 2][j] = lcs[(i - 1) % 2][j]
return lcs[len1 % 2][-1]
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="CRNN-Seq2Seq-OCR Evaluation")
parser.add_argument("--dataset_path", type=str, default="",
help="Test Dataset path")
parser.add_argument("--checkpoint_path", type=str, default=None,
help="Checkpoint of AttentionOCR (Default:None).")
parser.add_argument("--device_target", type=str, default="Ascend",
help="device where the code will be implemented, default is Ascend")
parser.add_argument("--device_id", type=int, default=0, help="Device id, default: 0.")
args = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
prefix = "fsns.mindrecord"
mindrecord_dir = args.dataset_path
mindrecord_file = os.path.join(mindrecord_dir, prefix + "0")
print("mindrecord_file", mindrecord_file)
dataset = create_ocr_val_dataset(mindrecord_file, config.eval_batch_size)
data_loader = dataset.create_dict_iterator(num_epochs=1, output_numpy=True)
print("Dataset creation Done!")
#Network
network = AttentionOCRInfer(config.eval_batch_size,
int(config.img_width / 4),
config.encoder_hidden_size,
config.decoder_hidden_size,
config.decoder_output_size,
config.max_length,
config.dropout_p)
ckpt = load_checkpoint(args.checkpoint_path)
load_param_into_net(network, ckpt)
network.set_train(False)
print("Checkpoint loading Done!")
vocab, rev_vocab = initialize_vocabulary(config.vocab_path)
eos_id = config.characters_dictionary.get("eos_id")
sos_id = config.characters_dictionary.get("go_id")
num_correct_char = 0
num_total_char = 0
num_correct_word = 0
num_total_word = 0
correct_file = 'result_correct.txt'
incorrect_file = 'result_incorrect.txt'
with codecs.open(correct_file, 'w', encoding='utf-8') as fp_output_correct, \
codecs.open(incorrect_file, 'w', encoding='utf-8') as fp_output_incorrect:
for data in data_loader:
images = Tensor(data["image"])
decoder_inputs = Tensor(data["decoder_input"])
decoder_targets = Tensor(data["decoder_target"])
decoder_hidden = Tensor(np.zeros((1, config.eval_batch_size, config.decoder_hidden_size),
dtype=np.float16), mstype.float16)
decoder_input = Tensor((np.ones((config.eval_batch_size, 1))*sos_id).astype(np.int32))
encoder_outputs = network.encoder(images)
batch_decoded_label = []
for di in range(decoder_inputs.shape[1]):
decoder_output, decoder_hidden, _ = network.decoder(decoder_input, decoder_hidden, encoder_outputs)
topi = P.Argmax()(decoder_output)
ni = P.ExpandDims()(topi, 1)
decoder_input = ni
topi_id = topi.asnumpy()
batch_decoded_label.append(topi_id)
for b in range(config.eval_batch_size):
text = data["annotation"][b].decode("utf8")
text = text_standardization(text)
decoded_label = list(np.array(batch_decoded_label)[:, b])
decoded_words = []
for idx in decoded_label:
if idx == eos_id:
break
else:
decoded_words.append(rev_vocab[idx])
predict = text_standardization("".join(decoded_words))
if predict == text:
num_correct_word += 1
fp_output_correct.write('\t\t' + text + '\n')
fp_output_correct.write('\t\t' + predict + '\n\n')
print('correctly predicted : pred: {}, gt: {}'.format(predict, text))
else:
fp_output_incorrect.write('\t\t' + text + '\n')
fp_output_incorrect.write('\t\t' + predict + '\n\n')
print('incorrectly predicted : pred: {}, gt: {}'.format(predict, text))
num_total_word += 1
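# character score: accumulate 2 * LCS(gt, pred) here; it is divided by len(gt) + len(pred) below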
num_correct_char += 2 * LCS_length(text, predict)
num_total_char += len(text) + len(predict)
print('\nnum of correct characters = %d' % (num_correct_char))
print('\nnum of total characters = %d' % (num_total_char))
print('\nnum of correct words = %d' % (num_correct_word))
print('\nnum of total words = %d' % (num_total_word))
print('\ncharacter precision = %f' % (float(num_correct_char) / num_total_char))
print('\nAnnotation precision = %f' % (float(num_correct_word) / num_total_word))


@ -0,0 +1,66 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -ne 2 ]
then
echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
echo $PATH1
if [ ! -f $PATH1 ]
then
echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
exit 1
fi
PATH2=$(get_real_path $2)
echo $PATH2
if [ ! -f $PATH2 ]
then
echo "error: PRETRAINED_PATH=$PATH2 is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
export RANK_TABLE_FILE=$PATH1
for((i=0; i<${DEVICE_NUM}; i++))
do
export RANK_ID=$i
export DEVICE_ID=$i
rm -rf ./train_parallel$i
mkdir ./train_parallel$i
cp ../*.py ./train_parallel$i
cp -r ../src ./train_parallel$i
cd ./train_parallel$i || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID"
env > env.log
python train.py --device_id=$DEVICE_ID --rank_id=$RANK_ID --is_distribute=1 --device_num=$DEVICE_NUM --mindrecord_file=$PATH2 &> log &
cd ..
done


@ -0,0 +1,64 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]
then
echo "Usage: sh run_eval_ascend.sh [DATASET_PATH] [CHECKPOINT_PATH]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)
echo $PATH1
echo $PATH2
if [ ! -d $PATH1 ]
then
echo "error: DATASET_PATH=$PATH1 is not a folder"
exit 1
fi
if [ ! -f $PATH2 ]
then
echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
exit 1
fi
export DEVICE_NUM=1
export RANK_SIZE=$DEVICE_NUM
export DEVICE_ID=0
export RANK_ID=0
if [ -d "eval" ];
then
rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp *.sh ./eval
cp -r ../src ./eval
cd ./eval || exit
env > env.log
echo "start eval for device $DEVICE_ID"
python eval.py --device_target="Ascend" --device_id=$DEVICE_ID --dataset_path=$PATH1 --checkpoint_path=$PATH2 &> log &
cd ..


@ -0,0 +1,58 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -ne 1 ]
then
echo "Usage: sh run_standalone_train_ascend.sh [DATASET_PATH]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
echo $PATH1
if [ ! -f $PATH1 ]
then
echo "error: DATASET_PATH=$PATH1 is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=1
export RANK_ID=0
export RANK_SIZE=1
if [ -d "train" ];
then
rm -rf ./train
fi
mkdir ./train
cp ../*.py ./train
cp *.sh ./train
cp -r ../src ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log
python train.py --device_id=$DEVICE_ID --mindrecord_file=$PATH1 --is_distributed=0 &> log &
cd ..


@ -0,0 +1,178 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#" ============================================================================
"""
CRNN-Seq2Seq-OCR model.
"""
import numpy as np
import mindspore as ms
import mindspore.nn as nn
import mindspore.ops.operations as P
import mindspore.common.dtype as mstype
from mindspore import context, Tensor
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.context import ParallelMode
from mindspore.communication.management import get_group_size
from mindspore.parallel._auto_parallel_context import auto_parallel_context
from src.seq2seq import Encoder, Decoder
class NLLLoss(_Loss):
def __init__(self, reduction='mean'):
super(NLLLoss, self).__init__(reduction)
self.one_hot = P.OneHot()
self.reduce_sum = P.ReduceSum()
def construct(self, logits, label):
label_one_hot = self.one_hot(label, F.shape(logits)[-1], F.scalar_to_array(1.0), F.scalar_to_array(0.0))
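# logits are log-probabilities (the decoder ends with LogSoftmax), so NLL is the negative log-prob at the target index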
loss = self.reduce_sum(-1.0 * logits * label_one_hot, (1,))
return self.get_loss(loss)
class AttentionOCRInfer(nn.Cell):
def __init__(self, batch_size, conv_out_dim, encoder_hidden_size, decoder_hidden_size,
decoder_output_size, max_length, dropout_p=0.1):
super(AttentionOCRInfer, self).__init__()
self.encoder = Encoder(batch_size=batch_size,
conv_out_dim=conv_out_dim,
hidden_size=encoder_hidden_size)
self.decoder = Decoder(hidden_size=decoder_hidden_size,
output_size=decoder_output_size,
max_length=max_length,
dropout_p=dropout_p)
def construct(self, img, decoder_input, decoder_hidden):
'''
get token output
'''
encoder_outputs = self.encoder(img)
decoder_output, decoder_hidden, decoder_attention = self.decoder(
decoder_input, decoder_hidden, encoder_outputs)
return decoder_output, decoder_hidden, decoder_attention
class AttentionOCR(nn.Cell):
def __init__(self, batch_size, conv_out_dim, encoder_hidden_size, decoder_hidden_size,
decoder_output_size, max_length, dropout_p=0.1):
super(AttentionOCR, self).__init__()
self.encoder = Encoder(batch_size=batch_size,
conv_out_dim=conv_out_dim,
hidden_size=encoder_hidden_size)
self.decoder = Decoder(hidden_size=decoder_hidden_size,
output_size=decoder_output_size,
max_length=max_length,
dropout_p=dropout_p)
self.init_decoder_hidden = Tensor(np.zeros((1, batch_size, decoder_hidden_size),
dtype=np.float16), mstype.float16)
self.shape = P.Shape()
self.split = P.Split(axis=1, output_num=max_length)
self.concat = P.Concat()
self.expand_dims = P.ExpandDims()
self.argmax = P.Argmax()
self.select = P.Select()
def construct(self, img, decoder_inputs, decoder_targets, teacher_force):
encoder_outputs = self.encoder(img)
_, text_len = self.shape(decoder_inputs)
decoder_outputs = ()
decoder_input_tuple = self.split(decoder_inputs)
decoder_target_tuple = self.split(decoder_targets)
decoder_input = decoder_input_tuple[0]
decoder_hidden = self.init_decoder_hidden
for i in range(text_len):
decoder_output, decoder_hidden, _ = self.decoder(decoder_input, decoder_hidden, encoder_outputs)
topi = self.argmax(decoder_output)
decoder_input_top = self.expand_dims(topi, 1)
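# teacher forcing: feed the ground-truth token when teacher_force is True, otherwise the model's own prediction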
decoder_input = self.select(teacher_force, decoder_target_tuple[i], decoder_input_top)
decoder_output = self.expand_dims(decoder_output, 0)
decoder_outputs += (decoder_output,)
outputs = self.concat(decoder_outputs)
return outputs
class AttentionOCRWithLossCell(nn.Cell):
"""AttentionOCR with Loss"""
def __init__(self, network, max_length):
super(AttentionOCRWithLossCell, self).__init__()
self.network = network
self.loss = NLLLoss()
self.shape = P.Shape()
self.add = P.AddN()
self.mean = P.ReduceMean()
self.split = P.Split(axis=0, output_num=max_length)
self.squeeze = P.Squeeze()
self.cast = P.Cast()
def construct(self, img, decoder_inputs, decoder_targets, teacher_force):
decoder_outputs = self.network(img, decoder_inputs, decoder_targets, teacher_force)
decoder_outputs = self.cast(decoder_outputs, mstype.float32)
_, text_len = self.shape(decoder_targets)
loss_total = ()
decoder_output_tuple = self.split(decoder_outputs)
for i in range(text_len):
loss = self.loss(self.squeeze(decoder_output_tuple[i]), decoder_targets[:, i])
loss = self.mean(loss)
loss_total += (loss,)
loss_output = self.add(loss_total)
return loss_output
grad_scale = C.MultitypeFuncGraph("grad_scale")
@grad_scale.register("Tensor", "Tensor")
def tensor_grad_scale(scale, grad):
return grad * P.Reciprocal()(scale)
class TrainingWrapper(nn.Cell):
def __init__(self, network, optimizer, sens=1.0):
super(TrainingWrapper, self).__init__(auto_prefix=False)
self.network = network
self.network.set_grad()
self.weights = ms.ParameterTuple(network.trainable_params())
self.optimizer = optimizer
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
self.sens = sens
self.reducer_flag = False
self.grad_reducer = None
# Set parallel_mode
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
self.reducer_flag = True
if self.reducer_flag:
mean = context.get_auto_parallel_context("gradients_mean")
if auto_parallel_context().get_device_num_is_set():
degree = context.get_auto_parallel_context("device_num")
else:
degree = get_group_size()
self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)
self.hyper_map = C.HyperMap()
def construct(self, *args):
weights = self.weights
loss = self.network(*args)
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
grads = self.grad(self.network, weights)(*args, sens)
if self.reducer_flag:
grads = self.grad_reducer(grads)
return F.depend(loss, self.optimizer(grads))


@ -0,0 +1,195 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#" ============================================================================
"""
CRNN-Seq2Seq-OCR CNN model.
"""
import math
import numpy as np
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.tensor import Tensor
def calculate_gain(nonlinearity, param=None):
"""calculate_gain"""
linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
res = 0
if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
res = 1
elif nonlinearity == 'tanh':
res = 5.0 / 3
elif nonlinearity == 'relu':
res = math.sqrt(2.0)
elif nonlinearity == 'leaky_relu':
if param is None:
negative_slope = 0.01
elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
negative_slope = param
else:
raise ValueError("negative_slope {} not a valid number".format(param))
res = math.sqrt(2.0 / (1 + negative_slope ** 2))
else:
raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
return res
def _calculate_fan_in_and_fan_out(tensor):
"""_calculate_fan_in_and_fan_out"""
dimensions = len(tensor)
if dimensions < 2:
raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
if dimensions == 2:
fan_in = tensor[1]
fan_out = tensor[0]
else:
num_input_fmaps = tensor[1]
num_output_fmaps = tensor[0]
receptive_field_size = 1
if dimensions > 2:
receptive_field_size = tensor[2] * tensor[3]
fan_in = num_input_fmaps * receptive_field_size
fan_out = num_output_fmaps * receptive_field_size
return fan_in, fan_out
def _calculate_correct_fan(tensor, mode):
mode = mode.lower()
valid_modes = ['fan_in', 'fan_out']
if mode not in valid_modes:
raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
return fan_in if mode == 'fan_in' else fan_out
def kaiming_normal(inputs_shape, gain_param=0, mode='fan_in', nonlinearity='leaky_relu'):
fan = _calculate_correct_fan(inputs_shape, mode)
gain = calculate_gain(nonlinearity, gain_param)
std = gain / math.sqrt(fan)
return np.random.normal(0, std, size=inputs_shape).astype(np.float32)
class ConvRelu(nn.Cell):
"""
Convolution Layer followed by Relu Layer
"""
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1):
super(ConvRelu, self).__init__()
shape = (out_channels, in_channels, kernel_size[0], kernel_size[1])
self.conv = nn.Conv2d(in_channels,
out_channels,
kernel_size,
stride,
weight_init=Tensor(kaiming_normal(shape)))
self.relu = nn.ReLU()
def construct(self, x):
x = self.conv(x)
x = self.relu(x)
return x
class ConvBNRelu(nn.Cell):
"""
Convolution Layer followed by Batch Normalization and Relu Layer
"""
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, pad_mode='same'):
super(ConvBNRelu, self).__init__()
shape = (out_channels, in_channels, kernel_size[0], kernel_size[1])
self.conv = nn.Conv2d(in_channels,
out_channels,
kernel_size, stride,
pad_mode=pad_mode,
weight_init=Tensor(kaiming_normal(shape)))
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU()
def construct(self, x):
x = self.conv(x)
x = self.bn(x)
x = self.relu(x)
return x
class CNN(nn.Cell):
"""
CNN Class for OCR
"""
def __init__(self, conv_out_dim):
super(CNN, self).__init__()
self.convRelu1 = ConvRelu(3, 64, (3, 3))
self.maxpool1 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
self.convRelu2 = ConvRelu(64, 128, (3, 3))
self.maxpool2 = nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1))
self.convBNRelu1 = ConvBNRelu(128, 256, (3, 3))
self.convRelu3 = ConvRelu(256, 256, (3, 3))
self.maxpool3 = nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1))
self.convBNRelu2 = ConvBNRelu(256, 384, (3, 3))
self.convRelu4 = ConvRelu(384, 384, (3, 3))
self.maxpool4 = nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1))
self.convBNRelu3 = ConvBNRelu(384, 384, (3, 3))
self.convRelu5 = ConvRelu(384, 384, (3, 3))
self.maxpool5 = nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1))
self.convBNRelu4 = ConvBNRelu(384, 384, (3, 3))
self.convRelu6 = ConvRelu(384, 384, (3, 3))
self.maxpool6 = nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1))
self.pad = nn.Pad(paddings=((0, 0), (0, 0), (0, 0), (0, 1)))
self.convBNRelu5 = ConvBNRelu(384, conv_out_dim, (2, 2), pad_mode='valid')
self.dropout = nn.Dropout(keep_prob=0.5)
self.squeeze = P.Squeeze(2)
self.cast = P.Cast()
def construct(self, x):
x = self.convRelu1(x)
x = self.maxpool1(x)
x = self.convRelu2(x)
x = self.maxpool2(x)
x = self.convBNRelu1(x)
x = self.convRelu3(x)
x = self.maxpool3(x)
x = self.convBNRelu2(x)
x = self.convRelu4(x)
x = self.maxpool4(x)
x = self.convBNRelu3(x)
x = self.convRelu5(x)
x = self.maxpool5(x)
x = self.convBNRelu4(x)
x = self.convRelu6(x)
x = self.maxpool6(x)
x = self.pad(x)
x = self.convBNRelu5(x)
x = self.dropout(x)
x = self.squeeze(x)
return x


@ -0,0 +1,61 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#" ============================================================================
"""Config parameters for CRNN-Seq2Seq-OCR model."""
from easydict import EasyDict as ed
config = ed({
# dataset-related
"mindrecord_dir": "",
"data_root": "",
"annotation_file": "",
"val_data_root": "",
"val_annotation_file": "",
"data_json": "",
"characters_dictionary": {"pad_id": 0, "go_id": 1, "eos_id": 2, "unk_id": 3},
"labels_not_use": [u'%#<23>?%', u'%#背景#%', u'%#不识<E4B88D>?%', u'#%不识<EFBFBD>?#', u'%#模糊#%', u'%#模糊#%'],
"vocab_path": "./general_chars.txt",
#model-related
"img_width": 512,
"img_height": 128,
"channel_size": 3,
"conv_out_dim": 384,
"encoder_hidden_size": 128,
"decoder_hidden_size": 128,
"decoder_output_size": 10000, # vocab_size is the decoder_output_size, characters_class+1, last 9999 is the space
"dropout_p": 0.1,
"max_length": 64,
"attn_num_layers": 1,
"teacher_force_ratio": 0.5,
#optimizer-related
"lr": 0.0008,
"adam_beta1": 0.5,
"adam_beta2": 0.999,
"loss_scale": 1024,
#train-related
"batch_size": 32,
"num_epochs": 20,
"keep_checkpoint_max": 20,
#eval-related
"eval_batch_size": 32
})


@ -0,0 +1,245 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Create FSNS MindRecord files."""
import os
import numpy as np
from mindspore.mindrecord import FileWriter
from src.config import config
from src.utils import initialize_vocabulary
def serialize_annotation(img_path, lex, vocab):
go_id = config.characters_dictionary.get("go_id")
eos_id = config.characters_dictionary.get("eos_id")
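# label layout: [go_id, c_1, ..., c_n, eos_id]; samples with special labels or out-of-vocabulary characters are skipped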
word = [go_id]
for special_label in config.labels_not_use:
if lex == special_label:
if config.print_no_train_label:
print("label in for image: %s is special label, related label is: %s, skip ..." % (img_path, lex))
return None
for c in lex:
if c not in vocab:
return None
c_idx = vocab.get(c)
word.append(c_idx)
word.append(eos_id)
word = np.array(word, dtype=np.int32)
return word
def create_fsns_label(image_dir, anno_file_dirs):
"""Get image path and annotation."""
if not os.path.isdir(image_dir):
raise ValueError(f'Cannot find {image_dir} dataset path.')
image_files_dict = {}
image_anno_dict = {}
images = []
img_id = 0
for anno_file_dir in anno_file_dirs:
anno_file = open(anno_file_dir, 'r').readlines()
for line in anno_file:
file_name = line.split('\t')[0]
labels = line.split('\t')[1].split('\n')[0]
image_path = os.path.join(image_dir, file_name)
if not os.path.isfile(image_path):
print(f'Cannot find image {image_path} according to annotations.')
continue
if labels:
images.append(img_id)
image_files_dict[img_id] = image_path
image_anno_dict[img_id] = labels
img_id += 1
return images, image_files_dict, image_anno_dict
def fsns_train_data_to_mindrecord(mindrecord_dir, prefix="data_ocr.mindrecord", file_num=8):
anno_file_dirs = [config.train_annotation_file]
images, image_path_dict, image_anno_dict = create_fsns_label(image_dir=config.data_root,
anno_file_dirs=anno_file_dirs)
vocab, _ = initialize_vocabulary(config.vocab_path)
data_schema = {"image": {"type": "bytes"},
"label": {"type": "int32", "shape": [-1]},
"decoder_input": {"type": "int32", "shape": [-1]},
"decoder_mask": {"type": "int32", "shape": [-1]},
"decoder_target": {"type": "int32", "shape": [-1]},
"annotation": {"type": "string"}}
mindrecord_path = os.path.join(mindrecord_dir, prefix)
writer = FileWriter(mindrecord_path, file_num)
writer.add_schema(data_schema, "ocr")
for img_id in images:
image_path = image_path_dict[img_id]
annotation = image_anno_dict[img_id]
label_max_len = config.max_text_len
text_max_len = config.max_text_len - 2
if len(annotation) > text_max_len:
continue
label = serialize_annotation(image_path, annotation, vocab)
if label is None:
continue
label_len = len(label)
decoder_input_len = label_max_len
if label_len <= decoder_input_len:
label = np.concatenate((label, np.zeros(decoder_input_len - label_len, dtype=np.int32)))
one_mask_len = label_len - config.go_shift
target_weight = np.concatenate((np.ones(one_mask_len, dtype=np.float32),
np.zeros(decoder_input_len - one_mask_len, dtype=np.float32)))
else:
continue
decoder_input = (np.array(label).T).astype(np.int32)
target_weight = (np.array(target_weight).T).astype(np.int32)
if not len(decoder_input) == len(target_weight):
continue
target = [decoder_input[i + 1] for i in range(len(decoder_input) - 1)]
target = (np.array(target)).astype(np.int32)
with open(image_path, 'rb') as f:
img = f.read()
row = {"image": img,
"label": label,
"decoder_input": decoder_input,
"decoder_mask": target_weight,
"decoder_target": target,
"annotation": str(annotation)}
writer.write_raw_data([row])
writer.commit()
def fsns_val_data_to_mindrecord(mindrecord_dir, prefix="data_ocr.mindrecord", file_num=8):
anno_file_dirs = [config.val_annotation_file]
images, image_path_dict, image_anno_dict = create_fsns_label(image_dir=config.val_data_root,
anno_file_dirs=anno_file_dirs)
vocab, _ = initialize_vocabulary(config.vocab_path)
data_schema = {"image": {"type": "bytes"},
"decoder_input": {"type": "int32", "shape": [-1]},
"decoder_target": {"type": "int32", "shape": [-1]},
"annotation": {"type": "string"}}
mindrecord_path = os.path.join(mindrecord_dir, prefix)
writer = FileWriter(mindrecord_path, file_num)
writer.add_schema(data_schema, "ocr")
for img_id in images:
image_path = image_path_dict[img_id]
annotation = image_anno_dict[img_id]
label_max_len = config.max_text_len
text_max_len = config.max_text_len - 2
if len(annotation) > text_max_len:
continue
label = serialize_annotation(image_path, annotation, vocab)
if label is None:
continue
label_len = len(label)
decoder_input_len = label_max_len
if label_len <= decoder_input_len:
label = np.concatenate((label, np.zeros(decoder_input_len - label_len, dtype=np.int32)))
else:
continue
decoder_input = (np.array(label).T).astype(np.int32)
target = [decoder_input[i + 1] for i in range(len(decoder_input) - 1)]
target = (np.array(target)).astype(np.int32)
with open(image_path, 'rb') as f:
img = f.read()
row = {"image": img,
"decoder_input": decoder_input,
"decoder_target": target,
"annotation": str(annotation)}
writer.write_raw_data([row])
writer.commit()
def create_mindrecord(dataset="fsns", prefix="fsns.mindrecord", is_training=True):
print("Start creating dataset!")
if is_training:
mindrecord_dir = os.path.join(config.mindrecord_dir, "train")
mindrecord_files = [os.path.join(mindrecord_dir, prefix + "0")]
if not os.path.exists(mindrecord_files[0]):
if not os.path.isdir(mindrecord_dir):
os.makedirs(mindrecord_dir)
if dataset == "fsns":
if os.path.isdir(config.data_root):
print("Create FSNS Mindrecord files for train pipeline.")
fsns_train_data_to_mindrecord(mindrecord_dir=mindrecord_dir, prefix=prefix, file_num=8)
print("Create FSNS Mindrecord files for train pipeline Done, at {}".format(mindrecord_dir))
else:
print("{} not exits!".format(config.data_root))
else:
print("{} dataset is not defined!".format(dataset))
if not is_training:
mindrecord_dir = os.path.join(config.mindrecord_dir, "val")
mindrecord_files = [os.path.join(mindrecord_dir, prefix + "0")]
if not os.path.exists(mindrecord_files[0]):
if not os.path.isdir(mindrecord_dir):
os.makedirs(mindrecord_dir)
if dataset == "fsns":
if os.path.isdir(config.val_data_root):
print("Create FSNS Mindrecord files for val pipeline.")
fsns_val_data_to_mindrecord(mindrecord_dir=mindrecord_dir, prefix=prefix)
print("Create FSNS Mindrecord files for val pipeline Done, at {}".format(mindrecord_dir))
else:
print("{} not exits!".format(config.val_data_root))
else:
print("{} dataset is not defined!".format(dataset))
return mindrecord_files


@ -0,0 +1,144 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""FSNS dataset"""
import cv2
import numpy as np
from PIL import Image
import mindspore.dataset as de
import mindspore.dataset.vision.c_transforms as C
import mindspore.dataset.vision.py_transforms as P
import mindspore.dataset.transforms.c_transforms as ops
import mindspore.common.dtype as mstype
from src.config import config
class AugmentationOps():
def __init__(self, min_area_ratio=0.8, aspect_ratio_range=(0.8, 1.2), brightness=32./255.,
contrast=0.5, saturation=0.5, hue=0.2, img_tile_shape=(150, 150)):
self.min_area_ratio = min_area_ratio
self.aspect_ratio_range = aspect_ratio_range
self.img_tile_shape = img_tile_shape
self.random_image_distortion_ops = P.RandomColorAdjust(brightness=brightness,
contrast=contrast,
saturation=saturation,
hue=hue)
def __call__(self, img):
img_h = self.img_tile_shape[0]
img_w = self.img_tile_shape[1]
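# an FSNS sample is four 150x150 views side by side; each view is randomly cropped,
# resized to 128x128 and color-jittered independently, then re-assembled into a 128x512 image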
img_new = np.zeros([128, 512, 3])
for i in range(4):
img_tile = img[:, (i*150):((i+1)*150), :]
# Random crop cut from the street sign image, resized to the same size.
# Assures that the crop covers at least 0.8 area of the input image.
# Aspect ratio of cropped image is within [0.8,1.2] range.
h = img_h + 1
w = img_w + 1
while (w >= img_w or h >= img_h):
aspect_ratio = np.random.uniform(self.aspect_ratio_range[0],
self.aspect_ratio_range[1])
h_low = np.ceil(np.sqrt(self.min_area_ratio * img_h * img_w / aspect_ratio))
h_high = np.floor(np.sqrt(img_h * img_w / aspect_ratio))
h = np.random.randint(h_low, h_high)
w = int(h * aspect_ratio)
y = np.random.randint(img_w - w)
x = np.random.randint(img_h - h)
img_tile = img_tile[x:(x+h), y:(y+w), :]
# Randomly chooses one of the 4 interpolation resize methods.
interpolation = np.random.choice([cv2.INTER_LINEAR,
cv2.INTER_CUBIC,
cv2.INTER_AREA,
cv2.INTER_NEAREST])
img_tile = cv2.resize(img_tile, (128, 128), interpolation=interpolation)
# Random color distortion ops.
img_tile_pil = Image.fromarray(img_tile)
img_tile_pil = self.random_image_distortion_ops(img_tile_pil)
img_tile = np.array(img_tile_pil)
img_new[:, (i*128):((i+1)*128), :] = img_tile
img_new = 2 * (img_new / 255.) - 1
return img_new
class ImageResizeWithRescale():
def __init__(self, standard_img_height, standard_img_width, channel_size=3):
self.standard_img_height = standard_img_height
self.standard_img_width = standard_img_width
self.channel_size = channel_size
def __call__(self, img):
img = cv2.resize(img, (self.standard_img_width, self.standard_img_height))
img = 2 * (img / 255.) - 1
return img
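# Sample the teacher-forcing flag once per example: with probability config.teacher_force_ratio
# the decoder is fed ground-truth tokens during training instead of its own predictions.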
def random_teacher_force(images, source_ids, target_ids):
teacher_force = np.random.random() < config.teacher_force_ratio
teacher_force_array = np.array([teacher_force], dtype=bool)
return images, source_ids, target_ids, teacher_force_array
def create_ocr_train_dataset(mindrecord_file, batch_size=32, rank_size=1, rank_id=0,
is_training=True, num_parallel_workers=4, use_multiprocessing=True):
ds = de.MindDataset(mindrecord_file,
columns_list=["image", "decoder_input", "decoder_target"],
num_shards=rank_size,
shard_id=rank_id,
num_parallel_workers=num_parallel_workers,
shuffle=is_training)
aug_ops = AugmentationOps()
transforms = [C.Decode(),
aug_ops,
C.HWC2CHW()]
ds = ds.map(operations=transforms, input_columns=["image"], python_multiprocessing=use_multiprocessing,
num_parallel_workers=num_parallel_workers)
ds = ds.map(operations=ops.PadEnd([config.max_length], 0), input_columns=["decoder_target"])
ds = ds.map(operations=random_teacher_force, input_columns=["image", "decoder_input", "decoder_target"],
output_columns=["image", "decoder_input", "decoder_target", "teacher_force"],
column_order=["image", "decoder_input", "decoder_target", "teacher_force"])
type_cast_op_bool = ops.TypeCast(mstype.bool_)
ds = ds.map(operations=type_cast_op_bool, input_columns="teacher_force")
print("Train dataset size= %s" % (int(ds.get_dataset_size())))
ds = ds.batch(batch_size, drop_remainder=True)
return ds
def create_ocr_val_dataset(mindrecord_file, batch_size=32, rank_size=1, rank_id=0,
num_parallel_workers=4, use_multiprocessing=True):
ds = de.MindDataset(mindrecord_file,
columns_list=["image", "annotation", "decoder_input", "decoder_target"],
num_shards=rank_size,
shard_id=rank_id,
num_parallel_workers=num_parallel_workers,
shuffle=False)
resize_rescale_op = ImageResizeWithRescale(standard_img_height=128, standard_img_width=512)
transforms = [C.Decode(),
resize_rescale_op,
C.HWC2CHW()]
ds = ds.map(operations=transforms, input_columns=["image"], python_multiprocessing=use_multiprocessing,
num_parallel_workers=num_parallel_workers)
ds = ds.map(operations=ops.PadEnd([config.max_length], 0), input_columns=["decoder_target"],
python_multiprocessing=use_multiprocessing, num_parallel_workers=8)
ds = ds.map(operations=ops.PadEnd([config.max_length], 0), input_columns=["decoder_input"],
python_multiprocessing=use_multiprocessing, num_parallel_workers=8)
ds = ds.batch(batch_size, drop_remainder=True)
print("Val dataset size= %s" % (str(int(ds.get_dataset_size())*batch_size)))
return ds


@ -0,0 +1,55 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#" ============================================================================
"""
GRU cell
"""
import mindspore.nn as nn
import mindspore.ops.operations as P
import mindspore.common.dtype as mstype
from src.weight_init import gru_default_state
class GRU(nn.Cell):
'''
GRU model
Args:
input_size: The number of expected features in the input
hidden_size: The number of features in the hidden state
'''
def __init__(self, input_size, hidden_size):
super(GRU, self).__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.weight_i, self.weight_h, self.bias_i, self.bias_h = gru_default_state(self.input_size, self.hidden_size)
self.rnn = P.DynamicGRUV2()
self.cast = P.Cast()
def construct(self, x, h):
'''
GRU construction
Args:
x(Tensor): GRU input
h(Tensor): GRU hidden state
Returns:
output(Tensor): rnn output
hidden(Tensor): hidden state
'''
x = self.cast(x, mstype.float16)
h = self.cast(h, mstype.float16)
y1, h1, _, _, _, _ = self.rnn(x, self.weight_i, self.weight_h, self.bias_i, self.bias_h, None, h)
return y1, h1


@ -0,0 +1,80 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Custom Logger."""
import os
import sys
import logging
from datetime import datetime
class LOGGER(logging.Logger):
"""
Logger.
Args:
logger_name: String. Logger name.
rank: Integer. Rank id.
"""
def __init__(self, logger_name, rank=0):
super(LOGGER, self).__init__(logger_name)
self.rank = rank
if rank % 8 == 0:
console = logging.StreamHandler(sys.stdout)
console.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
console.setFormatter(formatter)
self.addHandler(console)
def setup_logging_file(self, log_dir, rank=0):
"""Setup logging file."""
self.rank = rank
if not os.path.exists(log_dir):
os.makedirs(log_dir, exist_ok=True)
log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '_rank_{}.log'.format(rank)
self.log_fn = os.path.join(log_dir, log_name)
fh = logging.FileHandler(self.log_fn)
fh.setLevel(logging.INFO)
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
fh.setFormatter(formatter)
self.addHandler(fh)
def info(self, msg, *args, **kwargs):
if self.isEnabledFor(logging.INFO):
self._log(logging.INFO, msg, args, **kwargs)
def save_args(self, args):
self.info('Args:')
args_dict = vars(args)
for key in args_dict.keys():
self.info('--> %s: %s', key, args_dict[key])
self.info('')
def important_info(self, msg, *args, **kwargs):
if self.isEnabledFor(logging.INFO) and self.rank == 0:
line_width = 2
important_msg = '\n'
important_msg += ('*'*70 + '\n')*line_width
important_msg += ('*'*line_width + '\n')*2
important_msg += '*'*line_width + ' '*8 + msg + '\n'
important_msg += ('*'*line_width + '\n')*2
important_msg += ('*'*70 + '\n')*line_width
self.info(important_msg, *args, **kwargs)
def get_logger(path, rank):
"""Get Logger."""
logger = LOGGER('crnn-seq2seq-ocr', rank)
logger.setup_logging_file(path, rank)
return logger


@ -0,0 +1,196 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#" ============================================================================
"""lstm"""
import math
import numpy as np
from mindspore import nn, context, Tensor, Parameter, ParameterTuple
import mindspore.common.dtype as mstype
from mindspore.common.initializer import initializer
from mindspore.ops.primitive import constexpr
from mindspore.ops import operations as P
@constexpr
def _create_sequence_length(shape):
num_step, batch_size, _ = shape
sequence_length = Tensor(np.ones(batch_size, np.int32) * num_step, mstype.int32)
return sequence_length
class LSTM(nn.Cell):
"""
Stacked LSTM (Long Short-Term Memory) layers.
Args:
input_size (int): Number of features of input.
hidden_size (int): Number of features of hidden layer.
num_layers (int): Number of layers of stacked LSTM . Default: 1.
has_bias (bool): Whether the cell has bias `b_ih` and `b_hh`. Default: True.
batch_first (bool): Specifies whether the first dimension of input is batch_size. Default: False.
dropout (float, int): If not 0, append `Dropout` layer on the outputs of each
LSTM layer except the last layer. Default 0. The range of dropout is [0.0, 1.0].
bidirectional (bool): Specifies whether it is a bidirectional LSTM. Default: False.
Inputs:
- **input** (Tensor) - Tensor of shape (seq_len, batch_size, `input_size`) or
(batch_size, seq_len, `input_size`).
- **hx** (tuple) - A tuple of two Tensors (h_0, c_0) both of data type mindspore.float32 or
mindspore.float16 and shape (num_directions * `num_layers`, batch_size, `hidden_size`).
Data type of `hx` must be the same as `input`.
Outputs:
Tuple, a tuple contains (`output`, (`h_n`, `c_n`)).
- **output** (Tensor) - Tensor of shape (seq_len, batch_size, num_directions * `hidden_size`).
- **hx_n** (tuple) - A tuple of two Tensor (h_n, c_n) both of shape
(num_directions * `num_layers`, batch_size, `hidden_size`).
"""
def __init__(self,
input_size,
hidden_size,
num_layers=1,
has_bias=True,
batch_first=False,
dropout=0,
bidirectional=False):
super(LSTM, self).__init__()
self.is_ascend = context.get_context("device_target") == "Ascend"
self.batch_first = batch_first
self.transpose = P.Transpose()
self.num_layers = num_layers
self.bidirectional = bidirectional
self.dropout = dropout
self.lstm = P.LSTM(input_size=input_size,
hidden_size=hidden_size,
num_layers=num_layers,
has_bias=has_bias,
bidirectional=bidirectional,
dropout=float(dropout))
weight_size = 0
gate_size = 4 * hidden_size
stdv = 1 / math.sqrt(hidden_size)
num_directions = 2 if bidirectional else 1
if self.is_ascend:
self.reverse_seq = P.ReverseSequence(batch_dim=1, seq_dim=0)
self.concat = P.Concat(axis=0)
self.concat_2dim = P.Concat(axis=2)
self.cast = P.Cast()
self.shape = P.Shape()
if dropout < 0 or dropout > 1:
raise ValueError("For LSTM, dropout must be a number in range [0, 1], but got {}".format(dropout))
if dropout == 1:
self.dropout_op = P.ZerosLike()
else:
self.dropout_op = nn.Dropout(float(1 - dropout))
b0 = np.zeros(gate_size, dtype=np.float32)
self.w_list = []
self.b_list = []
self.rnns_fw = P.DynamicRNN(forget_bias=0.0)
self.rnns_bw = P.DynamicRNN(forget_bias=0.0)
for layer in range(num_layers):
w_shape = input_size if layer == 0 else (num_directions * hidden_size)
w_np = np.random.uniform(-stdv, stdv, (w_shape + hidden_size, gate_size)).astype(np.float32)
self.w_list.append(Parameter(
initializer(Tensor(w_np), [w_shape + hidden_size, gate_size]), name='weight_fw' + str(layer)))
if has_bias:
b_np = np.random.uniform(-stdv, stdv, gate_size).astype(np.float32)
self.b_list.append(Parameter(initializer(Tensor(b_np), [gate_size]), name='bias_fw' + str(layer)))
else:
self.b_list.append(Parameter(initializer(Tensor(b0), [gate_size]), name='bias_fw' + str(layer)))
if bidirectional:
w_bw_np = np.random.uniform(-stdv, stdv, (w_shape + hidden_size, gate_size)).astype(np.float32)
self.w_list.append(Parameter(initializer(Tensor(w_bw_np), [w_shape + hidden_size, gate_size]),
name='weight_bw' + str(layer)))
b_bw_np = np.random.uniform(-stdv, stdv, (4 * hidden_size)).astype(np.float32) if has_bias else b0
self.b_list.append(Parameter(initializer(Tensor(b_bw_np), [gate_size]),
name='bias_bw' + str(layer)))
self.w_list = ParameterTuple(self.w_list)
self.b_list = ParameterTuple(self.b_list)
else:
for layer in range(num_layers):
input_layer_size = input_size if layer == 0 else hidden_size * num_directions
increment_size = gate_size * input_layer_size
increment_size += gate_size * hidden_size
if has_bias:
increment_size += 2 * gate_size
weight_size += increment_size * num_directions
w_np = np.random.uniform(-stdv, stdv, (weight_size, 1, 1)).astype(np.float32)
self.weight = Parameter(initializer(Tensor(w_np), [weight_size, 1, 1]), name='weight')
def _stacked_bi_dynamic_rnn(self, x, init_h, init_c, weight, bias):
"""stacked bidirectional dynamic_rnn"""
x_shape = self.shape(x)
sequence_length = _create_sequence_length(x_shape)
pre_layer = x
hn = ()
cn = ()
output = x
for i in range(self.num_layers):
offset = i * 2
weight_fw, weight_bw = weight[offset], weight[offset + 1]
bias_fw, bias_bw = bias[offset], bias[offset + 1]
init_h_fw, init_h_bw = init_h[offset:offset + 1, :, :], init_h[offset + 1:offset + 2, :, :]
init_c_fw, init_c_bw = init_c[offset:offset + 1, :, :], init_c[offset + 1:offset + 2, :, :]
bw_x = self.reverse_seq(pre_layer, sequence_length)
y, h, c, _, _, _, _, _ = self.rnns_fw(pre_layer, weight_fw, bias_fw, None, init_h_fw, init_c_fw)
y_bw, h_bw, c_bw, _, _, _, _, _ = self.rnns_bw(bw_x, weight_bw, bias_bw, None, init_h_bw, init_c_bw)
y_bw = self.reverse_seq(y_bw, sequence_length)
output = self.concat_2dim((y, y_bw))
pre_layer = self.dropout_op(output) if self.dropout else output
hn += (h[-1:, :, :],)
hn += (h_bw[-1:, :, :],)
cn += (c[-1:, :, :],)
cn += (c_bw[-1:, :, :],)
status_h = self.concat(hn)
status_c = self.concat(cn)
return output, status_h, status_c
def _stacked_dynamic_rnn(self, x, init_h, init_c, weight, bias):
"""stacked mutil_layer dynamic_rnn"""
pre_layer = x
hn = ()
cn = ()
y = 0
for i in range(self.num_layers):
weight_fw, bias_fw = weight[i], bias[i]
init_h_fw, init_c_fw = init_h[i:i + 1, :, :], init_c[i:i + 1, :, :]
y, h, c, _, _, _, _, _ = self.rnns_fw(pre_layer, weight_fw, bias_fw, None, init_h_fw, init_c_fw)
pre_layer = self.dropout_op(y) if self.dropout else y
hn += (h[-1:, :, :],)
cn += (c[-1:, :, :],)
status_h = self.concat(hn)
status_c = self.concat(cn)
return y, status_h, status_c
def construct(self, x, hx):
if self.batch_first:
x = self.transpose(x, (1, 0, 2))
h, c = hx
if self.is_ascend:
x = self.cast(x, mstype.float16)
h = self.cast(h, mstype.float16)
c = self.cast(c, mstype.float16)
if self.bidirectional:
x, h, c = self._stacked_bi_dynamic_rnn(x, h, c, self.w_list, self.b_list)
else:
x, h, c = self._stacked_dynamic_rnn(x, h, c, self.w_list, self.b_list)
else:
x, h, c, _, _ = self.lstm(x, h, c, self.weight)
if self.batch_first:
x = self.transpose(x, (1, 0, 2))
return x, (h, c)


@ -0,0 +1,165 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#" ============================================================================
"""
Seq2Seq_OCR model.
"""
import mindspore.nn as nn
import mindspore.ops.operations as P
import mindspore.common.dtype as mstype
from src.cnn import CNN
from src.gru import GRU
from src.lstm import LSTM
from src.weight_init import lstm_default_state
class BidirectionalLSTM(nn.Cell):
"""Bidirectional LSTM with a Dense layer
Args:
batch_size(int): batch size of input data
input_size(int): Size of time sequence
hidden_size(int): the hidden size of LSTM layers
output_size(int): the output size of the dense layer
"""
def __init__(self, batch_size, input_size, hidden_size, output_size):
super(BidirectionalLSTM, self).__init__()
self.rnn = LSTM(input_size=input_size, hidden_size=hidden_size, bidirectional=True).to_float(mstype.float16)
self.h, self.c = lstm_default_state(batch_size, hidden_size, bidirectional=True)
self.embedding = nn.Dense(hidden_size * 2, output_size).to_float(mstype.float16)
self.shape = P.Shape()
self.reshape = P.Reshape()
self.cast = P.Cast()
def construct(self, inputs):
inputs = self.cast(inputs, mstype.float16)
recurrent, _ = self.rnn(inputs, (self.h, self.c))
T, b, h = self.shape(recurrent)
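        # merge the time and batch dimensions so the Dense projection is applied to every time step at once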
t_rec = self.reshape(recurrent, (T * b, h))
output = self.embedding(t_rec)
output = self.reshape(output, (T, b, -1))
return output
class AttnDecoderRNN(nn.Cell):
"""Attention Decoder Structure with a one-layer GRU
Args:
hidden_size(int): the hidden size
output_size(int): the output size
        max_length(int): max time step of the decoder
dropout_p(float): dropout probability, default is 0.1
"""
def __init__(self, hidden_size, output_size, max_length, dropout_p=0.1):
super(AttnDecoderRNN, self).__init__()
self.hidden_size = hidden_size
self.output_size = output_size
self.dropout_p = dropout_p
self.max_length = max_length
self.embedding = nn.Embedding(self.output_size, self.hidden_size)
self.attn = nn.Dense(in_channels=self.hidden_size * 2, out_channels=self.max_length).to_float(mstype.float16)
self.attn_combine = nn.Dense(in_channels=self.hidden_size * 2,
out_channels=self.hidden_size).to_float(mstype.float16)
self.dropout = nn.Dropout(keep_prob=1.0 - self.dropout_p)
self.gru = GRU(hidden_size, hidden_size).to_float(mstype.float16)
self.out = nn.Dense(in_channels=self.hidden_size, out_channels=self.output_size).to_float(mstype.float16)
self.transpose = P.Transpose()
self.concat = P.Concat(axis=2)
self.concat1 = P.Concat(axis=1)
self.softmax = P.Softmax(axis=1)
self.relu = P.ReLU()
self.log_softmax = P.LogSoftmax(axis=1)
self.bmm = P.BatchMatMul()
self.unsqueeze = P.ExpandDims()
self.squeeze = P.Squeeze(1)
self.squeeze1 = P.Squeeze(0)
self.cast = P.Cast()
def construct(self, inputs, hidden, encoder_outputs):
embedded = self.embedding(inputs)
embedded = self.transpose(embedded, (1, 0, 2))
embedded = self.dropout(embedded)
embedded = self.cast(embedded, mstype.float16)
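        # attention scores are computed from the current input embedding concatenated with the previous hidden state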
embedded_concat = self.concat((embedded, hidden))
embedded_concat = self.squeeze1(embedded_concat)
attn_weights = self.softmax(self.attn(embedded_concat))
attn_weights = self.unsqueeze(attn_weights, 1)
perm_encoder_outputs = self.transpose(encoder_outputs, (1, 0, 2))
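        # weight the encoder outputs by the attention distribution (batched matmul) to form the context vector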
attn_applied = self.bmm(attn_weights, perm_encoder_outputs)
attn_applied = self.squeeze(attn_applied)
embedded_squeeze = self.squeeze1(embedded)
output = self.concat1((embedded_squeeze, attn_applied))
output = self.attn_combine(output)
output = self.unsqueeze(output, 0)
output = self.relu(output)
gru_hidden = self.squeeze1(hidden)
output, hidden, _, _, _, _ = self.gru(output, gru_hidden)
output = self.squeeze1(output)
output = self.log_softmax(self.out(output))
return output, hidden, attn_weights
class Encoder(nn.Cell):
"""Encoder with a CNN and two BidirectionalLSTM layers
Args:
batch_size(int): batch size of input data
conv_out_dim(int): the output dimension of the cnn layer
hidden_size(int): the hidden size of LSTM layers
"""
def __init__(self, batch_size, conv_out_dim, hidden_size):
super(Encoder, self).__init__()
self.cnn = CNN(int(conv_out_dim/4))
self.lstm1 = BidirectionalLSTM(batch_size, conv_out_dim, hidden_size, hidden_size).to_float(mstype.float16)
self.lstm2 = BidirectionalLSTM(batch_size, hidden_size, hidden_size, hidden_size).to_float(mstype.float16)
self.transpose = P.Transpose()
self.cast = P.Cast()
self.split = P.Split(axis=3, output_num=4)
self.concat = P.Concat(axis=1)
def construct(self, inputs):
inputs = self.cast(inputs, mstype.float32)
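        # an FSNS sample stacks four views of the sign along the width; split them and run the shared CNN on each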
(x1, x2, x3, x4) = self.split(inputs)
conv1 = self.cnn(x1)
conv2 = self.cnn(x2)
conv3 = self.cnn(x3)
conv4 = self.cnn(x4)
conv = self.concat((conv1, conv2, conv3, conv4))
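        # rearrange to the time-major (seq_len, batch, feature) layout expected by the LSTM stack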
conv = self.transpose(conv, (2, 0, 1))
output = self.lstm1(conv)
output = self.lstm2(output)
return output
class Decoder(nn.Cell):
"""Decoder
Args:
hidden_size(int): the hidden size
output_size(int): the output size
        max_length(int): max time step of the decoder
dropout_p(float): dropout probability, default is 0.1
"""
def __init__(self, hidden_size, output_size, max_length, dropout_p=0.1):
super(Decoder, self).__init__()
self.decoder = AttnDecoderRNN(hidden_size, output_size, max_length, dropout_p)
def construct(self, inputs, hidden, encoder_outputs):
return self.decoder(inputs, hidden, encoder_outputs)
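# A minimal usage sketch (illustrative only; the batch size and dimensions below are assumptions, not fixed values):
#   encoder = Encoder(batch_size=32, conv_out_dim=512, hidden_size=128)
#   decoder = Decoder(hidden_size=128, output_size=num_classes, max_length=max_length)
#   encoder_outputs = encoder(images)                     # (seq_len, batch, hidden_size)
#   step_output, hidden, attn = decoder(tokens, hidden, encoder_outputs)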

View File

@ -0,0 +1,51 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Util class or function."""
from __future__ import absolute_import, division, print_function, unicode_literals
import os
import codecs
import logging
def initialize_vocabulary(vocabulary_path):
"""
    Initialize vocabulary from file.
    Assumes the vocabulary is stored one item per line.
"""
characters_class = 9999
if os.path.exists(vocabulary_path):
rev_vocab = []
with codecs.open(vocabulary_path, 'r', encoding='utf-8') as voc_file:
rev_vocab = [line.strip() for line in voc_file]
vocab = {x: y for (y, x) in enumerate(rev_vocab)}
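        # pad the vocabulary out to characters_class entries (plus a trailing space) so the class count stays fixed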
reserved_char_size = characters_class - len(rev_vocab)
if reserved_char_size < 0:
raise ValueError("Number of characters in vocabulary is equal or larger than config.characters_class")
for _ in range(reserved_char_size):
rev_vocab.append('')
# put space at the last position
vocab[' '] = len(rev_vocab)
rev_vocab.append(' ')
logging.info("Initializing vocabulary ends: %s", vocabulary_path)
return vocab, rev_vocab
raise ValueError("Initializing vocabulary ends: %s" % vocabulary_path)

View File

@ -0,0 +1,41 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#" ============================================================================
"""
weights initialization
"""
import math
import numpy as np
from mindspore import Tensor, Parameter
def lstm_default_state(batch_size, hidden_size, bidirectional, num_layers=1):
"""init default input."""
num_directions = 2 if bidirectional else 1
h = Tensor(np.zeros((num_layers * num_directions, batch_size, hidden_size)).astype(np.float32))
c = Tensor(np.zeros((num_layers * num_directions, batch_size, hidden_size)).astype(np.float32))
return h, c
def gru_default_state(input_size, hidden_size):
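    # uniform(-1/sqrt(hidden_size), 1/sqrt(hidden_size)) initialization for the GRU weights and biases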
stdv = 1 / math.sqrt(hidden_size)
weight_i = Parameter(Tensor(np.random.uniform(-stdv, stdv, (input_size, 3*hidden_size)).astype(np.float32)),
name='weight_i')
weight_h = Parameter(Tensor(np.random.uniform(-stdv, stdv, (input_size, 3*hidden_size)).astype(np.float32)),
name='weight_h')
bias_i = Parameter(Tensor(np.random.uniform(-stdv, stdv, (3*hidden_size)).astype(np.float32)),
name='bias_i')
bias_h = Parameter(Tensor(np.random.uniform(-stdv, stdv, (3*hidden_size)).astype(np.float32)),
name='bias_h')
return weight_i, weight_h, bias_i, bias_h

View File

@ -0,0 +1,158 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
CRNN-Seq2Seq-OCR train.
"""
import os
import argparse
import datetime
import mindspore.nn as nn
import mindspore.common.dtype as mstype
from mindspore.train.model import Model
from mindspore.context import ParallelMode
from mindspore.common import set_seed
from mindspore import Tensor
from mindspore import context
from mindspore.communication.management import init
from mindspore.train.callback import ModelCheckpoint
from mindspore.train.callback import CheckpointConfig, LossMonitor, TimeMonitor
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.config import config
from src.dataset import create_ocr_train_dataset
from src.logger import get_logger
from src.attention_ocr import AttentionOCR, AttentionOCRWithLossCell, TrainingWrapper
set_seed(1)
def parse_args():
"""Parse train arguments."""
parser = argparse.ArgumentParser('mindspore CRNN-Seq2Seq-OCR training')
# device related
parser.add_argument("--device_target", type=str, default="Ascend",
help="device where the code will be implemented.")
parser.add_argument("--device_id", type=int, default=0, help="Device id, default: 0.")
# distributed related
    parser.add_argument('--is_distributed', type=int, default=0,
                        help='Whether to run distributed training, 1 for yes, 0 for no. Default: 0')
    parser.add_argument('--rank_id', type=int, default=0, help='Local rank for distributed training. Default: 0')
parser.add_argument('--device_num', type=int, default=1, help='World size of device. Default: 1')
#dataset related
parser.add_argument('--mindrecord_file', type=str, default='', help='Train dataset directory.')
# logging related
parser.add_argument('--log_interval', type=int, default=100, help='Logging interval steps. Default: 100')
parser.add_argument('--ckpt_path', type=str, default='outputs/', help='Checkpoint save location. Default: outputs/')
    parser.add_argument('--pre_checkpoint_path', type=str, default='',
                        help='Pretrained checkpoint to load before training.')
parser.add_argument('--ckpt_interval', type=int, default=None, help='Save checkpoint interval. Default: None')
parser.add_argument('--is_save_on_master', type=int, default=0,
help='Save ckpt on master or all rank, 1 for master, 0 for all ranks. Default: 0')
args, _ = parser.parse_known_args()
# logger
args.outputs_dir = os.path.join(args.ckpt_path,
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
return args
def train():
"""Train function."""
args = parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
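    # distributed training: enable data-parallel mode and initialize the communication group (HCCL on Ascend)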
if args.is_distributed:
rank = args.rank_id
device_num = args.device_num
context.set_auto_parallel_context(device_num=device_num,
parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
init()
else:
rank = 0
device_num = 1
# Logger
args.logger = get_logger(args.outputs_dir, rank)
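    # decide whether this rank writes checkpoints: only the master when is_save_on_master is set, otherwise every rank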
args.rank_save_ckpt_flag = 0
if args.is_save_on_master:
if rank == 0:
args.rank_save_ckpt_flag = 1
else:
args.rank_save_ckpt_flag = 1
# DATASET
dataset = create_ocr_train_dataset(args.mindrecord_file,
config.batch_size,
rank_size=device_num,
rank_id=rank)
args.steps_per_epoch = dataset.get_dataset_size()
    args.logger.info('Finished loading dataset')
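    # default to saving a checkpoint once per epoch when no interval is given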
if not args.ckpt_interval:
args.ckpt_interval = args.steps_per_epoch
args.logger.save_args(args)
network = AttentionOCR(config.batch_size,
int(config.img_width / 4),
config.encoder_hidden_size,
config.decoder_hidden_size,
config.decoder_output_size,
config.max_length,
config.dropout_p)
if args.pre_checkpoint_path:
param_dict = load_checkpoint(args.pre_checkpoint_path)
load_param_into_net(network, param_dict)
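    # wrap the backbone with the NLL loss cell, then attach the Adam optimizer and the loss-scaled training wrapper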
network = AttentionOCRWithLossCell(network, config.max_length)
lr = Tensor(config.lr, mstype.float32)
opt = nn.Adam(network.trainable_params(), lr, beta1=config.adam_beta1, beta2=config.adam_beta2,
loss_scale=config.loss_scale)
network = TrainingWrapper(network, opt, sens=config.loss_scale)
    args.logger.info('Finished building the network')
callback = [TimeMonitor(data_size=1), LossMonitor()]
if args.rank_save_ckpt_flag:
ckpt_config = CheckpointConfig(save_checkpoint_steps=args.steps_per_epoch,
keep_checkpoint_max=config.keep_checkpoint_max)
save_ckpt_path = os.path.join(args.outputs_dir, 'ckpt_' + str(rank) + '/')
ckpt_cb = ModelCheckpoint(config=ckpt_config,
directory=save_ckpt_path,
prefix="crnn_seq2seq_ocr")
callback.append(ckpt_cb)
model = Model(network)
model.train(config.num_epochs, dataset, callbacks=callback, dataset_sink_mode=False)
args.logger.info('==========Training Done===============')
if __name__ == "__main__":
train()