!11788 Adding CRNN-Seq2Seq-OCR model to MindSpore model zoo
From: @alashkari Reviewed-by: Signed-off-by:
This commit is contained in:
commit
3952a57d85
|
@ -0,0 +1,196 @@
|
|||
# Contents
|
||||
|
||||
- [Contents](#contents)
|
||||
- [CRNN-Seq2Seq-OCR Description](#crnn-seq2seq-ocr-description)
|
||||
- [Model Architecture](#model-architecture)
|
||||
- [Dataset](#dataset)
|
||||
- [Dataset Prepare](#dataset-prepare)
|
||||
- [Environment Requirements](#environment-requirements)
|
||||
- [Quick Start](#quick-start)
|
||||
- [Script Description](#script-description)
|
||||
- [Script and Sample Code](#script-and-sample-code)
|
||||
- [Script Parameters](#script-parameters)
|
||||
- [Training Script Parameters](#training-script-parameters)
|
||||
- [Parameters Configuration](#parameters-configuration)
|
||||
- [Dataset Preparation](#dataset-preparation)
|
||||
- [Training Process](#training-process)
|
||||
- [Training](#training)
|
||||
- [Distributed Training](#distributed-training)
|
||||
- [Evaluation Process](#evaluation-process)
|
||||
- [Evaluation](#evaluation)
|
||||
- [Model Description](#model-description)
|
||||
- [Performance](#performance)
|
||||
- [Training Performance](#training-performance)
|
||||
- [Evaluation Performance](#evaluation-performance)
|
||||
|
||||
## [CRNN-Seq2Seq-OCR Description](#contents)
|
||||
|
||||
CRNN-Seq2Seq-OCR is a neural network model for image based sequence recognition tasks, such as scene text recognition and optical character recognition (OCR). Its architecture is a combination of CNN and sequence to sequence model with attention mechanism.
|
||||
|
||||
## [Model Architecture](#content)
|
||||
|
||||
CRNN-Seq2Seq-OCR applies a vgg structure to extract features from processed images, following with attention-based encoder and decoder layer, finally utilizes NLL to calculate loss. See src/attention_ocr.py for details.
|
||||
|
||||
## [Dataset](#content)
|
||||
|
||||
For training and evaluation, we use the French Street Name Signs (FSNS) released by Google as the training data, which contains approximately 1 million training images and their corresponding ground truth words.
|
||||
|
||||
## [Environment Requirements](#contents)
|
||||
|
||||
- Hardware(Ascend)
|
||||
- Prepare hardware environment with Ascend processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. You will be able to have access to related resources once approved.
|
||||
- Framework
|
||||
- [MindSpore](https://gitee.com/mindspore/mindspore)
|
||||
- For more information, please check the resources below:
|
||||
- [MindSpore Tutorials](https://www.mindspore.cn/tutorial/training/en/master/index.html)
|
||||
- [MindSpore Python API](https://www.mindspore.cn/doc/api_python/en/master/index.html)
|
||||
|
||||
## [Quick Start](#contents)
|
||||
|
||||
- After the dataset is prepared, you may start running the training or the evaluation scripts as follows:
|
||||
|
||||
- Running on Ascend
|
||||
|
||||
```shell
|
||||
# distribute training example in Ascend
|
||||
$ bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH]
|
||||
|
||||
# evaluation example in Ascend
|
||||
$ bash run_eval_ascend.sh [DATASET_PATH] [CHECKPOINT_PATH]
|
||||
|
||||
# standalone training example in Ascend
|
||||
$ bash run_standalone_train.sh [DATASET_NAME] [DATASET_PATH] [PLATFORM]
|
||||
```
|
||||
|
||||
For distributed training, a hccl configuration file with JSON format needs to be created in advance.
|
||||
|
||||
Please follow the instructions in the link below:
|
||||
[hccl_tools](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools).
|
||||
|
||||
## [Script Description](#contents)
|
||||
|
||||
### [Script and Sample Code](#contents)
|
||||
|
||||
```shell
|
||||
crnn-seq2seq-ocr
|
||||
├── README.md # Descriptions about CRNN-Seq2Seq-OCR
|
||||
├── scripts
|
||||
│ ├── run_distribute_train.sh # Launch distributed training on Ascend(8 pcs)
|
||||
│ ├── run_eval_ascend.sh # Launch Ascend evaluation
|
||||
│ └── run_standalone_train.sh # Launch standalone training on Ascend(1 pcs)
|
||||
├── src
|
||||
│ ├── attention_ocr.py # CRNN-Seq2Seq-OCR training wrapper
|
||||
│ ├── cnn.py # VGG network
|
||||
│ ├── config.py # Parameter configuration
|
||||
│ ├── create_mindrecord_files.py # Create mindrecord files from images and ground truth
|
||||
│ ├── dataset.py # Data preprocessing for training and evaluation
|
||||
│ ├── gru.py # GRU cell wrapper
|
||||
│ ├── logger.py # Logger configuration
|
||||
│ ├── lstm.py # LSTM cell wrapper
|
||||
│ ├── seq2seq.py # CRNN-Seq2Seq-OCR model structure
|
||||
│ └── utils.py # Utility functions for training and data pre-processing
|
||||
│ ├── weight_init.py # weight initialization of LSTM and GRU
|
||||
└── train.py # Training script
|
||||
├── eval.py # Evaluation Script
|
||||
```
|
||||
|
||||
### [Script Parameters](#contents)
|
||||
|
||||
#### Training Script Parameters
|
||||
|
||||
```shell
|
||||
# distributed training on Ascend
|
||||
Usage: bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH]
|
||||
|
||||
# standalone training
|
||||
Usage: bash run_standalone_train.sh [DATASET_PATH]
|
||||
```
|
||||
|
||||
#### Parameters Configuration
|
||||
|
||||
Parameters for both training and evaluation can be set in config.py.
|
||||
|
||||
### [Dataset Preparation](#contents)
|
||||
|
||||
- You may refer to "Generate dataset" in [Quick Start](#quick-start) to automatically generate a dataset, or you may choose to generate a text image dataset by yourself.
|
||||
|
||||
## [Training Process](#contents)
|
||||
|
||||
- Set options in `config.py`, including learning rate and other network hyperparameters. Click [MindSpore dataset preparation tutorial](https://www.mindspore.cn/tutorial/training/zh-CN/master/use/data_preparation.html) for more information about dataset.
|
||||
|
||||
### [Training](#contents)
|
||||
|
||||
- Run `run_standalone_train.sh` for non-distributed training of CRNN-Seq2Seq-OCR model, only support Ascend now.
|
||||
|
||||
``` bash
|
||||
bash run_standalone_train.sh [DATASET_PATH]
|
||||
```
|
||||
|
||||
#### [Distributed Training](#contents)
|
||||
|
||||
- Run `run_distribute_train.sh` for distributed training of CRNN-Seq2Seq-OCR model on Ascend.
|
||||
|
||||
``` bash
|
||||
bash run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH]
|
||||
```
|
||||
|
||||
Check the `train_parallel0/log.txt` and you will get outputs as following:
|
||||
|
||||
```shell
|
||||
epoch: 20 step: 4080, loss is 1.56112
|
||||
epoch: 20 step: 4081, loss is 1.6368448
|
||||
epoch time: 1559886.096 ms, per step time: 382.231 ms
|
||||
```
|
||||
|
||||
## [Evaluation Process](#contents)
|
||||
|
||||
### [Evaluation](#contents)
|
||||
|
||||
- Run `run_eval_ascend.sh` for evaluation on Ascend.
|
||||
|
||||
``` bash
|
||||
bash run_eval_ascend.sh [DATASET_PATH] [CHECKPOINT_PATH]
|
||||
```
|
||||
|
||||
Check the `eval/log` and you will get outputs as following:
|
||||
|
||||
```shell
|
||||
character precision = 0.967522
|
||||
|
||||
Annotation precision precision = 0.635204
|
||||
```
|
||||
|
||||
# Model Description
|
||||
|
||||
## Performance
|
||||
|
||||
### Evaluation Performance
|
||||
|
||||
| Parameters | Ascend |
|
||||
| -------------------------- | ----------------------------------------------------------- |
|
||||
| Model Version | V1 |
|
||||
| Resource | Ascend 910 ;CPU 2.60GHz,192cores;Memory,755G |
|
||||
| uploaded Date | 02/11/2021 (month/day/year) |
|
||||
| MindSpore Version | 1.2.0 |
|
||||
| Dataset | FSNS |
|
||||
| Training Parameters | epoch=20, batch_size=32 |
|
||||
| Optimizer | SGD |
|
||||
| Loss Function | Negative Log Likelihood |
|
||||
| Speed | 1pc: 355 ms/step; 8pcs: 385 ms/step |
|
||||
| Total time | 1pc: 64 hours; 8pcs: 9 hours |
|
||||
| Parameters (M) | 12 |
|
||||
| Scripts | [crnn_seq2seq_ocr script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/crnn_seq2seq_ocr) |
|
||||
|
||||
### Inference Performance
|
||||
|
||||
| Parameters | Ascend |
|
||||
| ------------------- | --------------------------- |
|
||||
| Model Version | V1 |
|
||||
| Resource | Ascend 910 |
|
||||
| Uploaded Date | 02/11/2021 (month/day/year) |
|
||||
| MindSpore Version | 1.2.0 |
|
||||
| Dataset | FSNS |
|
||||
| batch_size | 32 |
|
||||
| outputs | Annotation Precision, Character Precision |
|
||||
| Accuracy | Annotation Precision=63.52%, Character Precision=96.75% |
|
||||
| Model for inference | 12M (.ckpt file) |
|
|
@ -0,0 +1,181 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
CRNN-Seq2Seq-OCR Evaluation.
|
||||
|
||||
"""
|
||||
|
||||
import os
|
||||
import codecs
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
import mindspore.ops.operations as P
|
||||
import mindspore.common.dtype as mstype
|
||||
|
||||
from mindspore.common import set_seed
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
from src.config import config
|
||||
from src.utils import initialize_vocabulary
|
||||
from src.dataset import create_ocr_val_dataset
|
||||
from src.attention_ocr import AttentionOCRInfer
|
||||
|
||||
|
||||
set_seed(1)
|
||||
|
||||
|
||||
def text_standardization(text_in):
|
||||
"""
|
||||
replace some particular characters
|
||||
"""
|
||||
stand_text = text_in.strip()
|
||||
stand_text = ' '.join(stand_text.split())
|
||||
stand_text = stand_text.replace(u'(', u'(')
|
||||
stand_text = stand_text.replace(u')', u')')
|
||||
stand_text = stand_text.replace(u':', u':')
|
||||
return stand_text
|
||||
|
||||
|
||||
def LCS_length(str1, str2):
|
||||
"""
|
||||
calculate longest common sub-sequence between str1 and str2
|
||||
"""
|
||||
if str1 is None or str2 is None:
|
||||
return 0
|
||||
|
||||
len1 = len(str1)
|
||||
len2 = len(str2)
|
||||
if len1 == 0 or len2 == 0:
|
||||
return 0
|
||||
|
||||
lcs = [[0 for _ in range(len2 + 1)] for _ in range(2)]
|
||||
for i in range(1, len1 + 1):
|
||||
for j in range(1, len2 + 1):
|
||||
if str1[i - 1] == str2[j - 1]:
|
||||
lcs[i % 2][j] = lcs[(i - 1) % 2][j - 1] + 1
|
||||
else:
|
||||
if lcs[i % 2][j - 1] >= lcs[(i - 1) % 2][j]:
|
||||
lcs[i % 2][j] = lcs[i % 2][j - 1]
|
||||
else:
|
||||
lcs[i % 2][j] = lcs[(i - 1) % 2][j]
|
||||
|
||||
return lcs[len1 % 2][-1]
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description="CRNN-Seq2Seq-OCR Evaluation")
|
||||
parser.add_argument("--dataset_path", type=str, default="",
|
||||
help="Test Dataset path")
|
||||
parser.add_argument("--checkpoint_path", type=str, default=None,
|
||||
help="Checkpoint of AttentionOCR (Default:None).")
|
||||
parser.add_argument("--device_target", type=str, default="Ascend",
|
||||
help="device where the code will be implemented, default is Ascend")
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id, default: 0.")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
||||
|
||||
prefix = "fsns.mindrecord"
|
||||
mindrecord_dir = args.dataset_path
|
||||
mindrecord_file = os.path.join(mindrecord_dir, prefix + "0")
|
||||
print("mindrecord_file", mindrecord_file)
|
||||
dataset = create_ocr_val_dataset(mindrecord_file, config.eval_batch_size)
|
||||
data_loader = dataset.create_dict_iterator(num_epochs=1, output_numpy=True)
|
||||
print("Dataset creation Done!")
|
||||
|
||||
#Network
|
||||
network = AttentionOCRInfer(config.eval_batch_size,
|
||||
int(config.img_width / 4),
|
||||
config.encoder_hidden_size,
|
||||
config.decoder_hidden_size,
|
||||
config.decoder_output_size,
|
||||
config.max_length,
|
||||
config.dropout_p)
|
||||
|
||||
ckpt = load_checkpoint(args.checkpoint_path)
|
||||
load_param_into_net(network, ckpt)
|
||||
network.set_train(False)
|
||||
print("Checkpoint loading Done!")
|
||||
|
||||
vocab, rev_vocab = initialize_vocabulary(config.vocab_path)
|
||||
eos_id = config.characters_dictionary.get("eos_id")
|
||||
sos_id = config.characters_dictionary.get("go_id")
|
||||
|
||||
num_correct_char = 0
|
||||
num_total_char = 0
|
||||
num_correct_word = 0
|
||||
num_total_word = 0
|
||||
|
||||
correct_file = 'result_correct.txt'
|
||||
incorrect_file = 'result_incorrect.txt'
|
||||
|
||||
with codecs.open(correct_file, 'w', encoding='utf-8') as fp_output_correct, \
|
||||
codecs.open(incorrect_file, 'w', encoding='utf-8') as fp_output_incorrect:
|
||||
|
||||
for data in data_loader:
|
||||
images = Tensor(data["image"])
|
||||
decoder_inputs = Tensor(data["decoder_input"])
|
||||
decoder_targets = Tensor(data["decoder_target"])
|
||||
|
||||
decoder_hidden = Tensor(np.zeros((1, config.eval_batch_size, config.decoder_hidden_size),
|
||||
dtype=np.float16), mstype.float16)
|
||||
decoder_input = Tensor((np.ones((config.eval_batch_size, 1))*sos_id).astype(np.int32))
|
||||
encoder_outputs = network.encoder(images)
|
||||
batch_decoded_label = []
|
||||
|
||||
for di in range(decoder_inputs.shape[1]):
|
||||
decoder_output, decoder_hidden, _ = network.decoder(decoder_input, decoder_hidden, encoder_outputs)
|
||||
topi = P.Argmax()(decoder_output)
|
||||
ni = P.ExpandDims()(topi, 1)
|
||||
decoder_input = ni
|
||||
topi_id = topi.asnumpy()
|
||||
batch_decoded_label.append(topi_id)
|
||||
|
||||
for b in range(config.eval_batch_size):
|
||||
text = data["annotation"][b].decode("utf8")
|
||||
text = text_standardization(text)
|
||||
decoded_label = list(np.array(batch_decoded_label)[:, b])
|
||||
decoded_words = []
|
||||
for idx in decoded_label:
|
||||
if idx == eos_id:
|
||||
break
|
||||
else:
|
||||
decoded_words.append(rev_vocab[idx])
|
||||
predict = text_standardization("".join(decoded_words))
|
||||
|
||||
if predict == text:
|
||||
num_correct_word += 1
|
||||
fp_output_correct.write('\t\t' + text + '\n')
|
||||
fp_output_correct.write('\t\t' + predict + '\n\n')
|
||||
print('correctly predicted : pred: {}, gt: {}'.format(predict, text))
|
||||
|
||||
else:
|
||||
fp_output_incorrect.write('\t\t' + text + '\n')
|
||||
fp_output_incorrect.write('\t\t' + predict + '\n\n')
|
||||
print('incorrectly predicted : pred: {}, gt: {}'.format(predict, text))
|
||||
|
||||
num_total_word += 1
|
||||
num_correct_char += 2 * LCS_length(text, predict)
|
||||
num_total_char += len(text) + len(predict)
|
||||
|
||||
print('\nnum of correct characters = %d' % (num_correct_char))
|
||||
print('\nnum of total characters = %d' % (num_total_char))
|
||||
print('\nnum of correct words = %d' % (num_correct_word))
|
||||
print('\nnum of total words = %d' % (num_total_word))
|
||||
print('\ncharacter precision = %f' % (float(num_correct_char) / num_total_char))
|
||||
print('\nAnnotation precision precision = %f' % (float(num_correct_word) / num_total_word))
|
|
@ -0,0 +1,66 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# -ne 2 ]
|
||||
then
|
||||
echo "Usage: sh run_distribute_train.sh [RANK_TABLE_FILE] [DATASET_PATH]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
PATH1=$(get_real_path $1)
|
||||
echo $PATH1
|
||||
|
||||
if [ ! -f $PATH1 ]
|
||||
then
|
||||
echo "error: RANK_TABLE_FILE=$PATH1 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
PATH2=$(get_real_path $2)
|
||||
echo $PATH2
|
||||
if [ ! -f $PATH2 ]
|
||||
then
|
||||
echo "error: PRETRAINED_PATH=$PATH2 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=8
|
||||
export RANK_SIZE=8
|
||||
export RANK_TABLE_FILE=$PATH1
|
||||
|
||||
for((i=0; i<${DEVICE_NUM}; i++))
|
||||
do
|
||||
export RANK_ID=$i
|
||||
export DEVICE_ID=$i
|
||||
rm -rf ./train_parallel$i
|
||||
mkdir ./train_parallel$i
|
||||
cp ../*.py ./train_parallel$i
|
||||
cp -r ../src ./train_parallel$i
|
||||
cd ./train_parallel$i || exit
|
||||
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||
env > env.log
|
||||
python train.py --device_id=$DEVICE_ID --rank_id=$RANK_ID --is_distribute=1 --device_num=$DEVICE_NUM --mindrecord_file=$PATH2 &> log &
|
||||
cd ..
|
||||
done
|
|
@ -0,0 +1,64 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 2 ]
|
||||
then
|
||||
echo "Usage: sh run_eval_ascend.sh [DATASET_PATH] [CHECKPOINT_PATH]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
PATH1=$(get_real_path $1)
|
||||
PATH2=$(get_real_path $2)
|
||||
echo $PATH1
|
||||
echo $PATH2
|
||||
|
||||
if [ ! -d $PATH1 ]
|
||||
then
|
||||
echo "error: DATASET_PATH=$PATH1 is not a folder"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f $PATH2 ]
|
||||
then
|
||||
echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
export DEVICE_NUM=1
|
||||
export RANK_SIZE=$DEVICE_NUM
|
||||
export DEVICE_ID=0
|
||||
export RANK_ID=0
|
||||
|
||||
if [ -d "eval" ];
|
||||
then
|
||||
rm -rf ./eval
|
||||
fi
|
||||
mkdir ./eval
|
||||
cp ../*.py ./eval
|
||||
cp *.sh ./eval
|
||||
cp -r ../src ./eval
|
||||
cd ./eval || exit
|
||||
env > env.log
|
||||
echo "start eval for device $DEVICE_ID"
|
||||
python eval.py --device_target="Ascend" --device_id=$DEVICE_ID --dataset_path=$PATH1 --checkpoint_path=$PATH2 &> log &
|
||||
cd ..
|
|
@ -0,0 +1,58 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# -ne 1 ]
|
||||
then
|
||||
echo "Usage: sh run_standalone_train_ascend.sh [DATASET_PATH]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
PATH1=$(get_real_path $1)
|
||||
echo $PATH1
|
||||
|
||||
if [ ! -f $PATH1 ]
|
||||
then
|
||||
echo "error: DATASET_PATH=$PATH1 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=1
|
||||
export DEVICE_ID=1
|
||||
export RANK_ID=0
|
||||
export RANK_SIZE=1
|
||||
|
||||
if [ -d "train" ];
|
||||
then
|
||||
rm -rf ./train
|
||||
fi
|
||||
mkdir ./train
|
||||
cp ../*.py ./train
|
||||
cp *.sh ./train
|
||||
cp -r ../src ./train
|
||||
cd ./train || exit
|
||||
echo "start training for device $DEVICE_ID"
|
||||
env > env.log
|
||||
python train.py --device_id=$DEVICE_ID --mindrecord_file=$PATH1 --is_distributed=0 &> log &
|
||||
cd ..
|
|
@ -0,0 +1,178 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#" ============================================================================
|
||||
"""
|
||||
CRNN-Seq2Seq-OCR model.
|
||||
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
import mindspore as ms
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.operations as P
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore import context, Tensor
|
||||
from mindspore.nn.loss.loss import _Loss
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.communication.management import get_group_size
|
||||
from mindspore.parallel._auto_parallel_context import auto_parallel_context
|
||||
|
||||
from src.seq2seq import Encoder, Decoder
|
||||
|
||||
|
||||
class NLLLoss(_Loss):
|
||||
def __init__(self, reduction='mean'):
|
||||
super(NLLLoss, self).__init__(reduction)
|
||||
self.one_hot = P.OneHot()
|
||||
self.reduce_sum = P.ReduceSum()
|
||||
|
||||
def construct(self, logits, label):
|
||||
label_one_hot = self.one_hot(label, F.shape(logits)[-1], F.scalar_to_array(1.0), F.scalar_to_array(0.0))
|
||||
loss = self.reduce_sum(-1.0 * logits * label_one_hot, (1,))
|
||||
return self.get_loss(loss)
|
||||
|
||||
|
||||
class AttentionOCRInfer(nn.Cell):
|
||||
def __init__(self, batch_size, conv_out_dim, encoder_hidden_size, decoder_hidden_size,
|
||||
decoder_output_size, max_length, dropout_p=0.1):
|
||||
super(AttentionOCRInfer, self).__init__()
|
||||
|
||||
self.encoder = Encoder(batch_size=batch_size,
|
||||
conv_out_dim=conv_out_dim,
|
||||
hidden_size=encoder_hidden_size)
|
||||
|
||||
self.decoder = Decoder(hidden_size=decoder_hidden_size,
|
||||
output_size=decoder_output_size,
|
||||
max_length=max_length,
|
||||
dropout_p=dropout_p)
|
||||
|
||||
def construct(self, img, decoder_input, decoder_hidden):
|
||||
'''
|
||||
get token output
|
||||
'''
|
||||
encoder_outputs = self.encoder(img)
|
||||
decoder_output, decoder_hidden, decoder_attention = self.decoder(
|
||||
decoder_input, decoder_hidden, encoder_outputs)
|
||||
return decoder_output, decoder_hidden, decoder_attention
|
||||
|
||||
|
||||
class AttentionOCR(nn.Cell):
|
||||
def __init__(self, batch_size, conv_out_dim, encoder_hidden_size, decoder_hidden_size,
|
||||
decoder_output_size, max_length, dropout_p=0.1):
|
||||
super(AttentionOCR, self).__init__()
|
||||
self.encoder = Encoder(batch_size=batch_size,
|
||||
conv_out_dim=conv_out_dim,
|
||||
hidden_size=encoder_hidden_size)
|
||||
self.decoder = Decoder(hidden_size=decoder_hidden_size,
|
||||
output_size=decoder_output_size,
|
||||
max_length=max_length,
|
||||
dropout_p=dropout_p)
|
||||
self.init_decoder_hidden = Tensor(np.zeros((1, batch_size, decoder_hidden_size),
|
||||
dtype=np.float16), mstype.float16)
|
||||
self.shape = P.Shape()
|
||||
self.split = P.Split(axis=1, output_num=max_length)
|
||||
self.concat = P.Concat()
|
||||
self.expand_dims = P.ExpandDims()
|
||||
self.argmax = P.Argmax()
|
||||
self.select = P.Select()
|
||||
|
||||
def construct(self, img, decoder_inputs, decoder_targets, teacher_force):
|
||||
encoder_outputs = self.encoder(img)
|
||||
_, text_len = self.shape(decoder_inputs)
|
||||
decoder_outputs = ()
|
||||
decoder_input_tuple = self.split(decoder_inputs)
|
||||
decoder_target_tuple = self.split(decoder_targets)
|
||||
decoder_input = decoder_input_tuple[0]
|
||||
decoder_hidden = self.init_decoder_hidden
|
||||
|
||||
for i in range(text_len):
|
||||
decoder_output, decoder_hidden, _ = self.decoder(decoder_input, decoder_hidden, encoder_outputs)
|
||||
topi = self.argmax(decoder_output)
|
||||
decoder_input_top = self.expand_dims(topi, 1)
|
||||
decoder_input = self.select(teacher_force, decoder_target_tuple[i], decoder_input_top)
|
||||
decoder_output = self.expand_dims(decoder_output, 0)
|
||||
decoder_outputs += (decoder_output,)
|
||||
outputs = self.concat(decoder_outputs)
|
||||
return outputs
|
||||
|
||||
|
||||
class AttentionOCRWithLossCell(nn.Cell):
|
||||
"""AttentionOCR with Loss"""
|
||||
def __init__(self, network, max_length):
|
||||
super(AttentionOCRWithLossCell, self).__init__()
|
||||
self.network = network
|
||||
self.loss = NLLLoss()
|
||||
self.shape = P.Shape()
|
||||
self.add = P.AddN()
|
||||
self.mean = P.ReduceMean()
|
||||
self.split = P.Split(axis=0, output_num=max_length)
|
||||
self.squeeze = P.Squeeze()
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self, img, decoder_inputs, decoder_targets, teacher_force):
|
||||
decoder_outputs = self.network(img, decoder_inputs, decoder_targets, teacher_force)
|
||||
decoder_outputs = self.cast(decoder_outputs, mstype.float32)
|
||||
_, text_len = self.shape(decoder_targets)
|
||||
loss_total = ()
|
||||
decoder_output_tuple = self.split(decoder_outputs)
|
||||
for i in range(text_len):
|
||||
loss = self.loss(self.squeeze(decoder_output_tuple[i]), decoder_targets[:, i])
|
||||
loss = self.mean(loss)
|
||||
loss_total += (loss,)
|
||||
loss_output = self.add(loss_total)
|
||||
return loss_output
|
||||
|
||||
|
||||
grad_scale = C.MultitypeFuncGraph("grad_scale")
|
||||
@grad_scale.register("Tensor", "Tensor")
|
||||
def tensor_grad_scale(scale, grad):
|
||||
return grad * P.Reciprocal()(scale)
|
||||
|
||||
|
||||
class TrainingWrapper(nn.Cell):
|
||||
def __init__(self, network, optimizer, sens=1.0):
|
||||
super(TrainingWrapper, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.network.set_grad()
|
||||
self.weights = ms.ParameterTuple(network.trainable_params())
|
||||
self.optimizer = optimizer
|
||||
self.grad = C.GradOperation(get_by_list=True, sens_param=True)
|
||||
self.sens = sens
|
||||
self.reducer_flag = False
|
||||
self.grad_reducer = None
|
||||
|
||||
# Set parallel_mode
|
||||
self.parallel_mode = context.get_auto_parallel_context("parallel_mode")
|
||||
if self.parallel_mode in [ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL]:
|
||||
self.reducer_flag = True
|
||||
if self.reducer_flag:
|
||||
mean = context.get_auto_parallel_context("gradients_mean")
|
||||
if auto_parallel_context().get_device_num_is_set():
|
||||
degree = context.get_auto_parallel_context("device_num")
|
||||
else:
|
||||
degree = get_group_size()
|
||||
self.grad_reducer = nn.DistributedGradReducer(optimizer.parameters, mean, degree)
|
||||
self.hyper_map = C.HyperMap()
|
||||
|
||||
def construct(self, *args):
|
||||
weights = self.weights
|
||||
loss = self.network(*args)
|
||||
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
|
||||
grads = self.grad(self.network, weights)(*args, sens)
|
||||
if self.reducer_flag:
|
||||
grads = self.grad_reducer(grads)
|
||||
return F.depend(loss, self.optimizer(grads))
|
|
@ -0,0 +1,195 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#" ============================================================================
|
||||
"""
|
||||
CRN-Seq2Seq-OCR CNN model.
|
||||
|
||||
"""
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.tensor import Tensor
|
||||
|
||||
|
||||
def calculate_gain(nonlinearity, param=None):
|
||||
"""calculate_gain"""
|
||||
linear_fns = ['linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d', 'conv_transpose2d', 'conv_transpose3d']
|
||||
res = 0
|
||||
if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
|
||||
res = 1
|
||||
elif nonlinearity == 'tanh':
|
||||
res = 5.0 / 3
|
||||
elif nonlinearity == 'relu':
|
||||
res = math.sqrt(2.0)
|
||||
elif nonlinearity == 'leaky_relu':
|
||||
if param is None:
|
||||
negative_slope = 0.01
|
||||
elif not isinstance(param, bool) and isinstance(param, int) or isinstance(param, float):
|
||||
negative_slope = param
|
||||
else:
|
||||
raise ValueError("negative_slope {} not a valid number".format(param))
|
||||
res = math.sqrt(2.0 / (1 + negative_slope ** 2))
|
||||
else:
|
||||
raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
|
||||
return res
|
||||
|
||||
|
||||
def _calculate_fan_in_and_fan_out(tensor):
|
||||
"""_calculate_fan_in_and_fan_out"""
|
||||
dimensions = len(tensor)
|
||||
if dimensions < 2:
|
||||
raise ValueError("Fan in and fan out can not be computed for tensor with fewer than 2 dimensions")
|
||||
if dimensions == 2:
|
||||
fan_in = tensor[1]
|
||||
fan_out = tensor[0]
|
||||
else:
|
||||
num_input_fmaps = tensor[1]
|
||||
num_output_fmaps = tensor[0]
|
||||
receptive_field_size = 1
|
||||
if dimensions > 2:
|
||||
receptive_field_size = tensor[2] * tensor[3]
|
||||
fan_in = num_input_fmaps * receptive_field_size
|
||||
fan_out = num_output_fmaps * receptive_field_size
|
||||
return fan_in, fan_out
|
||||
|
||||
|
||||
def _calculate_correct_fan(tensor, mode):
|
||||
mode = mode.lower()
|
||||
valid_modes = ['fan_in', 'fan_out']
|
||||
if mode not in valid_modes:
|
||||
raise ValueError("Mode {} not supported, please use one of {}".format(mode, valid_modes))
|
||||
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
|
||||
return fan_in if mode == 'fan_in' else fan_out
|
||||
|
||||
|
||||
def kaiming_normal(inputs_shape, gain_param=0, mode='fan_in', nonlinearity='leaky_relu'):
|
||||
fan = _calculate_correct_fan(inputs_shape, mode)
|
||||
gain = calculate_gain(nonlinearity, gain_param)
|
||||
std = gain / math.sqrt(fan)
|
||||
return np.random.normal(0, std, size=inputs_shape).astype(np.float32)
|
||||
|
||||
|
||||
class ConvRelu(nn.Cell):
|
||||
"""
|
||||
Convolution Layer followed by Relu Layer
|
||||
|
||||
"""
|
||||
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1):
|
||||
super(ConvRelu, self).__init__()
|
||||
shape = (out_channels, in_channels, kernel_size[0], kernel_size[1])
|
||||
self.conv = nn.Conv2d(in_channels,
|
||||
out_channels,
|
||||
kernel_size,
|
||||
stride,
|
||||
weight_init=Tensor(kaiming_normal(shape)))
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.relu(x)
|
||||
return x
|
||||
|
||||
|
||||
class ConvBNRelu(nn.Cell):
|
||||
"""
|
||||
Convolution Layer followed by Batch Normalization and Relu Layer
|
||||
|
||||
"""
|
||||
def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, pad_mode='same'):
|
||||
super(ConvBNRelu, self).__init__()
|
||||
shape = (out_channels, in_channels, kernel_size[0], kernel_size[1])
|
||||
self.conv = nn.Conv2d(in_channels,
|
||||
out_channels,
|
||||
kernel_size, stride,
|
||||
pad_mode=pad_mode,
|
||||
weight_init=Tensor(kaiming_normal(shape)))
|
||||
self.bn = nn.BatchNorm2d(out_channels)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.bn(x)
|
||||
x = self.relu(x)
|
||||
return x
|
||||
|
||||
|
||||
class CNN(nn.Cell):
|
||||
"""
|
||||
CNN Class for OCR
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, conv_out_dim):
|
||||
super(CNN, self).__init__()
|
||||
self.convRelu1 = ConvRelu(3, 64, (3, 3))
|
||||
self.maxpool1 = nn.MaxPool2d(kernel_size=(2, 2), stride=(2, 2))
|
||||
|
||||
self.convRelu2 = ConvRelu(64, 128, (3, 3))
|
||||
self.maxpool2 = nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1))
|
||||
|
||||
self.convBNRelu1 = ConvBNRelu(128, 256, (3, 3))
|
||||
self.convRelu3 = ConvRelu(256, 256, (3, 3))
|
||||
self.maxpool3 = nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1))
|
||||
|
||||
self.convBNRelu2 = ConvBNRelu(256, 384, (3, 3))
|
||||
self.convRelu4 = ConvRelu(384, 384, (3, 3))
|
||||
self.maxpool4 = nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1))
|
||||
|
||||
self.convBNRelu3 = ConvBNRelu(384, 384, (3, 3))
|
||||
self.convRelu5 = ConvRelu(384, 384, (3, 3))
|
||||
self.maxpool5 = nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1))
|
||||
|
||||
self.convBNRelu4 = ConvBNRelu(384, 384, (3, 3))
|
||||
self.convRelu6 = ConvRelu(384, 384, (3, 3))
|
||||
self.maxpool6 = nn.MaxPool2d(kernel_size=(2, 1), stride=(2, 1))
|
||||
|
||||
self.pad = nn.Pad(paddings=((0, 0), (0, 0), (0, 0), (0, 1)))
|
||||
self.convBNRelu5 = ConvBNRelu(384, conv_out_dim, (2, 2), pad_mode='valid')
|
||||
self.dropout = nn.Dropout(keep_prob=0.5)
|
||||
|
||||
self.squeeze = P.Squeeze(2)
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.convRelu1(x)
|
||||
x = self.maxpool1(x)
|
||||
|
||||
x = self.convRelu2(x)
|
||||
x = self.maxpool2(x)
|
||||
|
||||
x = self.convBNRelu1(x)
|
||||
x = self.convRelu3(x)
|
||||
x = self.maxpool3(x)
|
||||
|
||||
x = self.convBNRelu2(x)
|
||||
x = self.convRelu4(x)
|
||||
x = self.maxpool4(x)
|
||||
|
||||
x = self.convBNRelu3(x)
|
||||
x = self.convRelu5(x)
|
||||
x = self.maxpool5(x)
|
||||
|
||||
x = self.convBNRelu4(x)
|
||||
x = self.convRelu6(x)
|
||||
x = self.maxpool6(x)
|
||||
|
||||
x = self.pad(x)
|
||||
x = self.convBNRelu5(x)
|
||||
x = self.dropout(x)
|
||||
x = self.squeeze(x)
|
||||
|
||||
return x
|
|
@ -0,0 +1,61 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#" ============================================================================
|
||||
"""Config parameters for CRNN-Seq2Seq-OCR model."""
|
||||
|
||||
from easydict import EasyDict as ed
|
||||
|
||||
|
||||
config = ed({
|
||||
|
||||
# dataset-related
|
||||
"mindrecord_dir": "",
|
||||
"data_root": "",
|
||||
"annotation_file": "",
|
||||
|
||||
"val_data_root": "",
|
||||
"val_annotation_file": "",
|
||||
"data_json": "",
|
||||
|
||||
"characters_dictionary": {"pad_id": 0, "go_id": 1, "eos_id": 2, "unk_id": 3},
|
||||
"labels_not_use": [u'%#<23>?%', u'%#背景#%', u'%#不识<E4B88D>?%', u'#%不识<EFBFBD>?#', u'%#模糊#%', u'%#模糊#%'],
|
||||
"vocab_path": "./general_chars.txt",
|
||||
|
||||
#model-related
|
||||
"img_width": 512,
|
||||
"img_height": 128,
|
||||
"channel_size": 3,
|
||||
"conv_out_dim": 384,
|
||||
"encoder_hidden_size": 128,
|
||||
"decoder_hidden_size": 128,
|
||||
"decoder_output_size": 10000, # vocab_size is the decoder_output_size, characters_class+1, last 9999 is the space
|
||||
"dropout_p": 0.1,
|
||||
"max_length": 64,
|
||||
"attn_num_layers": 1,
|
||||
"teacher_force_ratio": 0.5,
|
||||
|
||||
#optimizer-related
|
||||
"lr": 0.0008,
|
||||
"adam_beta1": 0.5,
|
||||
"adam_beta2": 0.999,
|
||||
"loss_scale": 1024,
|
||||
|
||||
#train-related
|
||||
"batch_size": 32,
|
||||
"num_epochs": 20,
|
||||
"keep_checkpoint_max": 20,
|
||||
|
||||
#eval-related
|
||||
"eval_batch_size": 32
|
||||
})
|
|
@ -0,0 +1,245 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Create FSNS MindRecord files."""
|
||||
|
||||
import os
|
||||
import numpy as np
|
||||
|
||||
from mindspore.mindrecord import FileWriter
|
||||
|
||||
from src.config import config
|
||||
from src.utils import initialize_vocabulary
|
||||
|
||||
|
||||
def serialize_annotation(img_path, lex, vocab):
|
||||
|
||||
go_id = config.characters_dictionary.get("go_id")
|
||||
eos_id = config.characters_dictionary.get("eos_id")
|
||||
|
||||
word = [go_id]
|
||||
for special_label in config.labels_not_use:
|
||||
if lex == special_label:
|
||||
if config.print_no_train_label:
|
||||
print("label in for image: %s is special label, related label is: %s, skip ..." % (img_path, lex))
|
||||
return None
|
||||
|
||||
for c in lex:
|
||||
if c not in vocab:
|
||||
return None
|
||||
|
||||
c_idx = vocab.get(c)
|
||||
word.append(c_idx)
|
||||
|
||||
word.append(eos_id)
|
||||
word = np.array(word, dtype=np.int32)
|
||||
return word
|
||||
|
||||
def create_fsns_label(image_dir, anno_file_dirs):
|
||||
"""Get image path and annotation."""
|
||||
|
||||
if not os.path.isdir(image_dir):
|
||||
raise ValueError(f'Cannot find {image_dir} dataset path.')
|
||||
|
||||
image_files_dict = {}
|
||||
image_anno_dict = {}
|
||||
images = []
|
||||
img_id = 0
|
||||
|
||||
for anno_file_dir in anno_file_dirs:
|
||||
|
||||
anno_file = open(anno_file_dir, 'r').readlines()
|
||||
|
||||
for line in anno_file:
|
||||
|
||||
file_name = line.split('\t')[0]
|
||||
labels = line.split('\t')[1].split('\n')[0]
|
||||
image_path = os.path.join(image_dir, file_name)
|
||||
|
||||
if not os.path.isfile(image_path):
|
||||
print(f'Cannot find image {image_path} according to annotations.')
|
||||
continue
|
||||
|
||||
if labels:
|
||||
images.append(img_id)
|
||||
image_files_dict[img_id] = image_path
|
||||
image_anno_dict[img_id] = labels
|
||||
img_id += 1
|
||||
|
||||
return images, image_files_dict, image_anno_dict
|
||||
|
||||
|
||||
def fsns_train_data_to_mindrecord(mindrecord_dir, prefix="data_ocr.mindrecord", file_num=8):
|
||||
|
||||
anno_file_dirs = [config.train_annotation_file]
|
||||
images, image_path_dict, image_anno_dict = create_fsns_label(image_dir=config.data_root,
|
||||
anno_file_dirs=anno_file_dirs)
|
||||
vocab, _ = initialize_vocabulary(config.vocab_path)
|
||||
|
||||
data_schema = {"image": {"type": "bytes"},
|
||||
"label": {"type": "int32", "shape": [-1]},
|
||||
"decoder_input": {"type": "int32", "shape": [-1]},
|
||||
"decoder_mask": {"type": "int32", "shape": [-1]},
|
||||
"decoder_target": {"type": "int32", "shape": [-1]},
|
||||
"annotation": {"type": "string"}}
|
||||
|
||||
mindrecord_path = os.path.join(mindrecord_dir, prefix)
|
||||
|
||||
writer = FileWriter(mindrecord_path, file_num)
|
||||
writer.add_schema(data_schema, "ocr")
|
||||
|
||||
for img_id in images:
|
||||
|
||||
image_path = image_path_dict[img_id]
|
||||
annotation = image_anno_dict[img_id]
|
||||
|
||||
label_max_len = config.max_text_len
|
||||
text_max_len = config.max_text_len - 2
|
||||
|
||||
if len(annotation) > text_max_len:
|
||||
continue
|
||||
label = serialize_annotation(image_path, annotation, vocab)
|
||||
|
||||
if label is None:
|
||||
continue
|
||||
|
||||
label_len = len(label)
|
||||
decoder_input_len = label_max_len
|
||||
|
||||
if label_len <= decoder_input_len:
|
||||
label = np.concatenate((label, np.zeros(decoder_input_len - label_len, dtype=np.int32)))
|
||||
one_mask_len = label_len - config.go_shift
|
||||
target_weight = np.concatenate((np.ones(one_mask_len, dtype=np.float32),
|
||||
np.zeros(decoder_input_len - one_mask_len, dtype=np.float32)))
|
||||
else:
|
||||
continue
|
||||
|
||||
decoder_input = (np.array(label).T).astype(np.int32)
|
||||
target_weight = (np.array(target_weight).T).astype(np.int32)
|
||||
|
||||
if not len(decoder_input) == len(target_weight):
|
||||
continue
|
||||
|
||||
target = [decoder_input[i + 1] for i in range(len(decoder_input) - 1)]
|
||||
target = (np.array(target)).astype(np.int32)
|
||||
|
||||
|
||||
with open(image_path, 'rb') as f:
|
||||
img = f.read()
|
||||
|
||||
row = {"image": img,
|
||||
"label": label,
|
||||
"decoder_input": decoder_input,
|
||||
"decoder_mask": target_weight,
|
||||
"decoder_target": target,
|
||||
"annotation": str(annotation)}
|
||||
|
||||
writer.write_raw_data([row])
|
||||
writer.commit()
|
||||
|
||||
|
||||
def fsns_val_data_to_mindrecord(mindrecord_dir, prefix="data_ocr.mindrecord", file_num=8):
|
||||
|
||||
anno_file_dirs = [config.train_annotation_file]
|
||||
images, image_path_dict, image_anno_dict = create_fsns_label(image_dir=config.data_root,
|
||||
anno_file_dirs=anno_file_dirs)
|
||||
vocab, _ = initialize_vocabulary(config.vocab_path)
|
||||
|
||||
data_schema = {"image": {"type": "bytes"},
|
||||
"decoder_input": {"type": "int32", "shape": [-1]},
|
||||
"decoder_target": {"type": "int32", "shape": [-1]},
|
||||
"annotation": {"type": "string"}}
|
||||
|
||||
mindrecord_path = os.path.join(mindrecord_dir, prefix)
|
||||
|
||||
writer = FileWriter(mindrecord_path, file_num)
|
||||
writer.add_schema(data_schema, "ocr")
|
||||
|
||||
for img_id in images:
|
||||
|
||||
image_path = image_path_dict[img_id]
|
||||
annotation = image_anno_dict[img_id]
|
||||
|
||||
label_max_len = config.max_text_len
|
||||
text_max_len = config.max_text_len - 2
|
||||
|
||||
if len(annotation) > text_max_len:
|
||||
continue
|
||||
label = serialize_annotation(image_path, annotation, vocab)
|
||||
|
||||
if label is None:
|
||||
continue
|
||||
|
||||
label_len = len(label)
|
||||
decoder_input_len = label_max_len
|
||||
|
||||
if label_len <= decoder_input_len:
|
||||
label = np.concatenate((label, np.zeros(decoder_input_len - label_len, dtype=np.int32)))
|
||||
else:
|
||||
continue
|
||||
|
||||
decoder_input = (np.array(label).T).astype(np.int32)
|
||||
|
||||
target = [decoder_input[i + 1] for i in range(len(decoder_input) - 1)]
|
||||
target = (np.array(target)).astype(np.int32)
|
||||
|
||||
|
||||
with open(image_path, 'rb') as f:
|
||||
img = f.read()
|
||||
|
||||
row = {"image": img,
|
||||
"decoder_input": decoder_input,
|
||||
"decoder_target": target,
|
||||
"annotation": str(annotation)}
|
||||
|
||||
writer.write_raw_data([row])
|
||||
writer.commit()
|
||||
|
||||
def create_mindrecord(dataset="fsns", prefix="fsns.mindrecord", is_training=True):
|
||||
print("Start creating dataset!")
|
||||
if is_training:
|
||||
mindrecord_dir = os.path.join(config.mindrecord_dir, "train")
|
||||
mindrecord_files = [os.path.join(mindrecord_dir, prefix + "0")]
|
||||
|
||||
if not os.path.exists(mindrecord_files[0]):
|
||||
if not os.path.isdir(mindrecord_dir):
|
||||
os.makedirs(mindrecord_dir)
|
||||
if dataset == "fsns":
|
||||
if os.path.isdir(config.data_root):
|
||||
print("Create FSNS Mindrecord files for train pipeline.")
|
||||
fsns_train_data_to_mindrecord(mindrecord_dir=mindrecord_dir, prefix=prefix, file_num=8)
|
||||
print("Create FSNS Mindrecord files for train pipeline Done, at {}".format(mindrecord_dir))
|
||||
else:
|
||||
print("{} not exits!".format(config.data_root))
|
||||
else:
|
||||
print("{} dataset is not defined!".format(dataset))
|
||||
|
||||
if not is_training:
|
||||
mindrecord_dir = os.path.join(config.mindrecord_dir, "val")
|
||||
mindrecord_files = [os.path.join(mindrecord_dir, prefix + "0")]
|
||||
|
||||
if not os.path.exists(mindrecord_files[0]):
|
||||
if not os.path.isdir(mindrecord_dir):
|
||||
os.makedirs(mindrecord_dir)
|
||||
if dataset == "fsns":
|
||||
if os.path.isdir(config.val_data_root):
|
||||
print("Create FSNS Mindrecord files for val pipeline.")
|
||||
fsns_val_data_to_mindrecord(mindrecord_dir=mindrecord_dir, prefix=prefix)
|
||||
print("Create FSNS Mindrecord files for val pipeline Done, at {}".format(mindrecord_dir))
|
||||
else:
|
||||
print("{} not exits!".format(config.val_data_root))
|
||||
else:
|
||||
print("{} dataset is not defined!".format(dataset))
|
||||
|
||||
return mindrecord_files
|
|
@ -0,0 +1,144 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""FSNS dataset"""
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
import mindspore.dataset as de
|
||||
import mindspore.dataset.vision.c_transforms as C
|
||||
import mindspore.dataset.vision.py_transforms as P
|
||||
import mindspore.dataset.transforms.c_transforms as ops
|
||||
import mindspore.common.dtype as mstype
|
||||
|
||||
from src.config import config
|
||||
|
||||
|
||||
class AugmentationOps():
|
||||
def __init__(self, min_area_ratio=0.8, aspect_ratio_range=(0.8, 1.2), brightness=32./255.,
|
||||
contrast=0.5, saturation=0.5, hue=0.2, img_tile_shape=(150, 150)):
|
||||
self.min_area_ratio = min_area_ratio
|
||||
self.aspect_ratio_range = aspect_ratio_range
|
||||
self.img_tile_shape = img_tile_shape
|
||||
self.random_image_distortion_ops = P.RandomColorAdjust(brightness=brightness,
|
||||
contrast=contrast,
|
||||
saturation=saturation,
|
||||
hue=hue)
|
||||
|
||||
def __call__(self, img):
|
||||
img_h = self.img_tile_shape[0]
|
||||
img_w = self.img_tile_shape[1]
|
||||
img_new = np.zeros([128, 512, 3])
|
||||
|
||||
for i in range(4):
|
||||
img_tile = img[:, (i*150):((i+1)*150), :]
|
||||
# Random crop cut from the street sign image, resized to the same size.
|
||||
# Assures that the crop covers at least 0.8 area of the input image.
|
||||
# Aspect ratio of cropped image is within [0.8,1.2] range.
|
||||
h = img_h + 1
|
||||
w = img_w + 1
|
||||
|
||||
while (w >= img_w or h >= img_h):
|
||||
aspect_ratio = np.random.uniform(self.aspect_ratio_range[0],
|
||||
self.aspect_ratio_range[1])
|
||||
h_low = np.ceil(np.sqrt(self.min_area_ratio * img_h * img_w / aspect_ratio))
|
||||
h_high = np.floor(np.sqrt(img_h * img_w / aspect_ratio))
|
||||
h = np.random.randint(h_low, h_high)
|
||||
w = int(h * aspect_ratio)
|
||||
|
||||
y = np.random.randint(img_w - w)
|
||||
x = np.random.randint(img_h - h)
|
||||
img_tile = img_tile[x:(x+h), y:(y+w), :]
|
||||
# Randomly chooses one of the 4 interpolation resize methods.
|
||||
interpolation = np.random.choice([cv2.INTER_LINEAR,
|
||||
cv2.INTER_CUBIC,
|
||||
cv2.INTER_AREA,
|
||||
cv2.INTER_NEAREST])
|
||||
img_tile = cv2.resize(img_tile, (128, 128), interpolation=interpolation)
|
||||
# Random color distortion ops.
|
||||
img_tile_pil = Image.fromarray(img_tile)
|
||||
img_tile_pil = self.random_image_distortion_ops(img_tile_pil)
|
||||
img_tile = np.array(img_tile_pil)
|
||||
img_new[:, (i*128):((i+1)*128), :] = img_tile
|
||||
|
||||
img_new = 2 * (img_new / 255.) - 1
|
||||
return img_new
|
||||
|
||||
|
||||
class ImageResizeWithRescale():
|
||||
def __init__(self, standard_img_height, standard_img_width, channel_size=3):
|
||||
self.standard_img_height = standard_img_height
|
||||
self.standard_img_width = standard_img_width
|
||||
self.channel_size = channel_size
|
||||
|
||||
def __call__(self, img):
|
||||
img = cv2.resize(img, (self.standard_img_width, self.standard_img_height))
|
||||
img = 2 * (img / 255.) - 1
|
||||
return img
|
||||
|
||||
|
||||
def random_teacher_force(images, source_ids, target_ids):
|
||||
teacher_force = np.random.random() < config.teacher_force_ratio
|
||||
teacher_force_array = np.array([teacher_force], dtype=bool)
|
||||
return images, source_ids, target_ids, teacher_force_array
|
||||
|
||||
|
||||
def create_ocr_train_dataset(mindrecord_file, batch_size=32, rank_size=1, rank_id=0,
|
||||
is_training=True, num_parallel_workers=4, use_multiprocessing=True):
|
||||
ds = de.MindDataset(mindrecord_file,
|
||||
columns_list=["image", "decoder_input", "decoder_target"],
|
||||
num_shards=rank_size,
|
||||
shard_id=rank_id,
|
||||
num_parallel_workers=num_parallel_workers,
|
||||
shuffle=is_training)
|
||||
aug_ops = AugmentationOps()
|
||||
transforms = [C.Decode(),
|
||||
aug_ops,
|
||||
C.HWC2CHW()]
|
||||
ds = ds.map(operations=transforms, input_columns=["image"], python_multiprocessing=use_multiprocessing,
|
||||
num_parallel_workers=num_parallel_workers)
|
||||
ds = ds.map(operations=ops.PadEnd([config.max_length], 0), input_columns=["decoder_target"])
|
||||
ds = ds.map(operations=random_teacher_force, input_columns=["image", "decoder_input", "decoder_target"],
|
||||
output_columns=["image", "decoder_input", "decoder_target", "teacher_force"],
|
||||
column_order=["image", "decoder_input", "decoder_target", "teacher_force"])
|
||||
type_cast_op_bool = ops.TypeCast(mstype.bool_)
|
||||
ds = ds.map(operations=type_cast_op_bool, input_columns="teacher_force")
|
||||
print("Train dataset size= %s" % (int(ds.get_dataset_size())))
|
||||
ds = ds.batch(batch_size, drop_remainder=True)
|
||||
return ds
|
||||
|
||||
|
||||
def create_ocr_val_dataset(mindrecord_file, batch_size=32, rank_size=1, rank_id=0,
|
||||
num_parallel_workers=4, use_multiprocessing=True):
|
||||
ds = de.MindDataset(mindrecord_file,
|
||||
columns_list=["image", "annotation", "decoder_input", "decoder_target"],
|
||||
num_shards=rank_size,
|
||||
shard_id=rank_id,
|
||||
num_parallel_workers=num_parallel_workers,
|
||||
shuffle=False)
|
||||
resize_rescale_op = ImageResizeWithRescale(standard_img_height=128, standard_img_width=512)
|
||||
transforms = [C.Decode(),
|
||||
resize_rescale_op,
|
||||
C.HWC2CHW()]
|
||||
ds = ds.map(operations=transforms, input_columns=["image"], python_multiprocessing=use_multiprocessing,
|
||||
num_parallel_workers=num_parallel_workers)
|
||||
ds = ds.map(operations=ops.PadEnd([config.max_length], 0), input_columns=["decoder_target"],
|
||||
python_multiprocessing=use_multiprocessing, num_parallel_workers=8)
|
||||
ds = ds.map(operations=ops.PadEnd([config.max_length], 0), input_columns=["decoder_input"],
|
||||
python_multiprocessing=use_multiprocessing, num_parallel_workers=8)
|
||||
ds = ds.batch(batch_size, drop_remainder=True)
|
||||
print("Val dataset size= %s" % (str(int(ds.get_dataset_size())*batch_size)))
|
||||
return ds
|
|
@ -0,0 +1,55 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#" ============================================================================
|
||||
"""
|
||||
GRU cell
|
||||
"""
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.operations as P
|
||||
import mindspore.common.dtype as mstype
|
||||
from src.weight_init import gru_default_state
|
||||
|
||||
|
||||
class GRU(nn.Cell):
|
||||
'''
|
||||
GRU model
|
||||
|
||||
Args:
|
||||
input_size: The number of expected features in the input
|
||||
hidden_size: The number of features in the hidden state
|
||||
'''
|
||||
def __init__(self, input_size, hidden_size):
|
||||
super(GRU, self).__init__()
|
||||
self.input_size = input_size
|
||||
self.hidden_size = hidden_size
|
||||
self.weight_i, self.weight_h, self.bias_i, self.bias_h = gru_default_state(self.input_size, self.hidden_size)
|
||||
self.rnn = P.DynamicGRUV2()
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self, x, h):
|
||||
'''
|
||||
GRU construction
|
||||
|
||||
Args:
|
||||
x(Tensor): GRU input
|
||||
h(Tensor): GRU hidden state
|
||||
|
||||
Returns:
|
||||
output(Tensor): rnn output
|
||||
hidden(Tensor): hidden state
|
||||
'''
|
||||
x = self.cast(x, mstype.float16)
|
||||
h = self.cast(h, mstype.float16)
|
||||
y1, h1, _, _, _, _ = self.rnn(x, self.weight_i, self.weight_h, self.bias_i, self.bias_h, None, h)
|
||||
return y1, h1
|
|
@ -0,0 +1,80 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Custom Logger."""
|
||||
import os
|
||||
import sys
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class LOGGER(logging.Logger):
|
||||
"""
|
||||
Logger.
|
||||
|
||||
Args:
|
||||
logger_name: String. Logger name.
|
||||
rank: Integer. Rank id.
|
||||
"""
|
||||
def __init__(self, logger_name, rank=0):
|
||||
super(LOGGER, self).__init__(logger_name)
|
||||
self.rank = rank
|
||||
if rank % 8 == 0:
|
||||
console = logging.StreamHandler(sys.stdout)
|
||||
console.setLevel(logging.INFO)
|
||||
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
|
||||
console.setFormatter(formatter)
|
||||
self.addHandler(console)
|
||||
|
||||
def setup_logging_file(self, log_dir, rank=0):
|
||||
"""Setup logging file."""
|
||||
self.rank = rank
|
||||
if not os.path.exists(log_dir):
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
log_name = datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S') + '_rank_{}.log'.format(rank)
|
||||
self.log_fn = os.path.join(log_dir, log_name)
|
||||
fh = logging.FileHandler(self.log_fn)
|
||||
fh.setLevel(logging.INFO)
|
||||
formatter = logging.Formatter('%(asctime)s:%(levelname)s:%(message)s')
|
||||
fh.setFormatter(formatter)
|
||||
self.addHandler(fh)
|
||||
|
||||
def info(self, msg, *args, **kwargs):
|
||||
if self.isEnabledFor(logging.INFO):
|
||||
self._log(logging.INFO, msg, args, **kwargs)
|
||||
|
||||
def save_args(self, args):
|
||||
self.info('Args:')
|
||||
args_dict = vars(args)
|
||||
for key in args_dict.keys():
|
||||
self.info('--> %s: %s', key, args_dict[key])
|
||||
self.info('')
|
||||
|
||||
def important_info(self, msg, *args, **kwargs):
|
||||
if self.isEnabledFor(logging.INFO) and self.rank == 0:
|
||||
line_width = 2
|
||||
important_msg = '\n'
|
||||
important_msg += ('*'*70 + '\n')*line_width
|
||||
important_msg += ('*'*line_width + '\n')*2
|
||||
important_msg += '*'*line_width + ' '*8 + msg + '\n'
|
||||
important_msg += ('*'*line_width + '\n')*2
|
||||
important_msg += ('*'*70 + '\n')*line_width
|
||||
self.info(important_msg, *args, **kwargs)
|
||||
|
||||
|
||||
def get_logger(path, rank):
|
||||
"""Get Logger."""
|
||||
logger = LOGGER('crnn-seq2seq-ocr', rank)
|
||||
logger.setup_logging_file(path, rank)
|
||||
return logger
|
|
@ -0,0 +1,196 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#" ============================================================================
|
||||
"""lstm"""
|
||||
import math
|
||||
import numpy as np
|
||||
from mindspore import nn, context, Tensor, Parameter, ParameterTuple
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.common.initializer import initializer
|
||||
from mindspore.ops.primitive import constexpr
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
@constexpr
|
||||
def _create_sequence_length(shape):
|
||||
num_step, batch_size, _ = shape
|
||||
sequence_length = Tensor(np.ones(batch_size, np.int32) * num_step, mstype.int32)
|
||||
return sequence_length
|
||||
|
||||
class LSTM(nn.Cell):
|
||||
"""
|
||||
Stacked LSTM (Long Short-Term Memory) layers.
|
||||
|
||||
Args:
|
||||
input_size (int): Number of features of input.
|
||||
hidden_size (int): Number of features of hidden layer.
|
||||
num_layers (int): Number of layers of stacked LSTM . Default: 1.
|
||||
has_bias (bool): Whether the cell has bias `b_ih` and `b_hh`. Default: True.
|
||||
batch_first (bool): Specifies whether the first dimension of input is batch_size. Default: False.
|
||||
dropout (float, int): If not 0, append `Dropout` layer on the outputs of each
|
||||
LSTM layer except the last layer. Default 0. The range of dropout is [0.0, 1.0].
|
||||
bidirectional (bool): Specifies whether it is a bidirectional LSTM. Default: False.
|
||||
|
||||
Inputs:
|
||||
- **input** (Tensor) - Tensor of shape (seq_len, batch_size, `input_size`) or
|
||||
(batch_size, seq_len, `input_size`).
|
||||
- **hx** (tuple) - A tuple of two Tensors (h_0, c_0) both of data type mindspore.float32 or
|
||||
mindspore.float16 and shape (num_directions * `num_layers`, batch_size, `hidden_size`).
|
||||
Data type of `hx` must be the same as `input`.
|
||||
|
||||
Outputs:
|
||||
Tuple, a tuple contains (`output`, (`h_n`, `c_n`)).
|
||||
|
||||
- **output** (Tensor) - Tensor of shape (seq_len, batch_size, num_directions * `hidden_size`).
|
||||
- **hx_n** (tuple) - A tuple of two Tensor (h_n, c_n) both of shape
|
||||
(num_directions * `num_layers`, batch_size, `hidden_size`).
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
input_size,
|
||||
hidden_size,
|
||||
num_layers=1,
|
||||
has_bias=True,
|
||||
batch_first=False,
|
||||
dropout=0,
|
||||
bidirectional=False):
|
||||
super(LSTM, self).__init__()
|
||||
self.is_ascend = context.get_context("device_target") == "Ascend"
|
||||
|
||||
self.batch_first = batch_first
|
||||
self.transpose = P.Transpose()
|
||||
self.num_layers = num_layers
|
||||
self.bidirectional = bidirectional
|
||||
self.dropout = dropout
|
||||
self.lstm = P.LSTM(input_size=input_size,
|
||||
hidden_size=hidden_size,
|
||||
num_layers=num_layers,
|
||||
has_bias=has_bias,
|
||||
bidirectional=bidirectional,
|
||||
dropout=float(dropout))
|
||||
|
||||
weight_size = 0
|
||||
gate_size = 4 * hidden_size
|
||||
stdv = 1 / math.sqrt(hidden_size)
|
||||
num_directions = 2 if bidirectional else 1
|
||||
if self.is_ascend:
|
||||
self.reverse_seq = P.ReverseSequence(batch_dim=1, seq_dim=0)
|
||||
self.concat = P.Concat(axis=0)
|
||||
self.concat_2dim = P.Concat(axis=2)
|
||||
self.cast = P.Cast()
|
||||
self.shape = P.Shape()
|
||||
if dropout < 0 or dropout > 1:
|
||||
raise ValueError("For LSTM, dropout must be a number in range [0, 1], but got {}".format(dropout))
|
||||
if dropout == 1:
|
||||
self.dropout_op = P.ZerosLike()
|
||||
else:
|
||||
self.dropout_op = nn.Dropout(float(1 - dropout))
|
||||
b0 = np.zeros(gate_size, dtype=np.float32)
|
||||
self.w_list = []
|
||||
self.b_list = []
|
||||
self.rnns_fw = P.DynamicRNN(forget_bias=0.0)
|
||||
self.rnns_bw = P.DynamicRNN(forget_bias=0.0)
|
||||
|
||||
for layer in range(num_layers):
|
||||
w_shape = input_size if layer == 0 else (num_directions * hidden_size)
|
||||
w_np = np.random.uniform(-stdv, stdv, (w_shape + hidden_size, gate_size)).astype(np.float32)
|
||||
self.w_list.append(Parameter(
|
||||
initializer(Tensor(w_np), [w_shape + hidden_size, gate_size]), name='weight_fw' + str(layer)))
|
||||
if has_bias:
|
||||
b_np = np.random.uniform(-stdv, stdv, gate_size).astype(np.float32)
|
||||
self.b_list.append(Parameter(initializer(Tensor(b_np), [gate_size]), name='bias_fw' + str(layer)))
|
||||
else:
|
||||
self.b_list.append(Parameter(initializer(Tensor(b0), [gate_size]), name='bias_fw' + str(layer)))
|
||||
if bidirectional:
|
||||
w_bw_np = np.random.uniform(-stdv, stdv, (w_shape + hidden_size, gate_size)).astype(np.float32)
|
||||
self.w_list.append(Parameter(initializer(Tensor(w_bw_np), [w_shape + hidden_size, gate_size]),
|
||||
name='weight_bw' + str(layer)))
|
||||
b_bw_np = np.random.uniform(-stdv, stdv, (4 * hidden_size)).astype(np.float32) if has_bias else b0
|
||||
self.b_list.append(Parameter(initializer(Tensor(b_bw_np), [gate_size]),
|
||||
name='bias_bw' + str(layer)))
|
||||
self.w_list = ParameterTuple(self.w_list)
|
||||
self.b_list = ParameterTuple(self.b_list)
|
||||
else:
|
||||
for layer in range(num_layers):
|
||||
input_layer_size = input_size if layer == 0 else hidden_size * num_directions
|
||||
increment_size = gate_size * input_layer_size
|
||||
increment_size += gate_size * hidden_size
|
||||
if has_bias:
|
||||
increment_size += 2 * gate_size
|
||||
weight_size += increment_size * num_directions
|
||||
w_np = np.random.uniform(-stdv, stdv, (weight_size, 1, 1)).astype(np.float32)
|
||||
self.weight = Parameter(initializer(Tensor(w_np), [weight_size, 1, 1]), name='weight')
|
||||
|
||||
def _stacked_bi_dynamic_rnn(self, x, init_h, init_c, weight, bias):
|
||||
"""stacked bidirectional dynamic_rnn"""
|
||||
x_shape = self.shape(x)
|
||||
sequence_length = _create_sequence_length(x_shape)
|
||||
pre_layer = x
|
||||
hn = ()
|
||||
cn = ()
|
||||
output = x
|
||||
for i in range(self.num_layers):
|
||||
offset = i * 2
|
||||
weight_fw, weight_bw = weight[offset], weight[offset + 1]
|
||||
bias_fw, bias_bw = bias[offset], bias[offset + 1]
|
||||
init_h_fw, init_h_bw = init_h[offset:offset + 1, :, :], init_h[offset + 1:offset + 2, :, :]
|
||||
init_c_fw, init_c_bw = init_c[offset:offset + 1, :, :], init_c[offset + 1:offset + 2, :, :]
|
||||
bw_x = self.reverse_seq(pre_layer, sequence_length)
|
||||
y, h, c, _, _, _, _, _ = self.rnns_fw(pre_layer, weight_fw, bias_fw, None, init_h_fw, init_c_fw)
|
||||
y_bw, h_bw, c_bw, _, _, _, _, _ = self.rnns_bw(bw_x, weight_bw, bias_bw, None, init_h_bw, init_c_bw)
|
||||
y_bw = self.reverse_seq(y_bw, sequence_length)
|
||||
output = self.concat_2dim((y, y_bw))
|
||||
pre_layer = self.dropout_op(output) if self.dropout else output
|
||||
hn += (h[-1:, :, :],)
|
||||
hn += (h_bw[-1:, :, :],)
|
||||
cn += (c[-1:, :, :],)
|
||||
cn += (c_bw[-1:, :, :],)
|
||||
status_h = self.concat(hn)
|
||||
status_c = self.concat(cn)
|
||||
return output, status_h, status_c
|
||||
|
||||
def _stacked_dynamic_rnn(self, x, init_h, init_c, weight, bias):
|
||||
"""stacked mutil_layer dynamic_rnn"""
|
||||
pre_layer = x
|
||||
hn = ()
|
||||
cn = ()
|
||||
y = 0
|
||||
for i in range(self.num_layers):
|
||||
weight_fw, bias_bw = weight[i], bias[i]
|
||||
init_h_fw, init_c_bw = init_h[i:i + 1, :, :], init_c[i:i + 1, :, :]
|
||||
y, h, c, _, _, _, _, _ = self.rnns_fw(pre_layer, weight_fw, bias_bw, None, init_h_fw, init_c_bw)
|
||||
pre_layer = self.dropout_op(y) if self.dropout else y
|
||||
hn += (h[-1:, :, :],)
|
||||
cn += (c[-1:, :, :],)
|
||||
status_h = self.concat(hn)
|
||||
status_c = self.concat(cn)
|
||||
return y, status_h, status_c
|
||||
|
||||
def construct(self, x, hx):
|
||||
if self.batch_first:
|
||||
x = self.transpose(x, (1, 0, 2))
|
||||
h, c = hx
|
||||
if self.is_ascend:
|
||||
x = self.cast(x, mstype.float16)
|
||||
h = self.cast(h, mstype.float16)
|
||||
c = self.cast(c, mstype.float16)
|
||||
if self.bidirectional:
|
||||
x, h, c = self._stacked_bi_dynamic_rnn(x, h, c, self.w_list, self.b_list)
|
||||
else:
|
||||
x, h, c = self._stacked_dynamic_rnn(x, h, c, self.w_list, self.b_list)
|
||||
else:
|
||||
x, h, c, _, _ = self.lstm(x, h, c, self.weight)
|
||||
if self.batch_first:
|
||||
x = self.transpose(x, (1, 0, 2))
|
||||
return x, (h, c)
|
|
@ -0,0 +1,165 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#" ============================================================================
|
||||
"""
|
||||
Seq2Seq_OCR model.
|
||||
|
||||
"""
|
||||
import mindspore.nn as nn
|
||||
import mindspore.ops.operations as P
|
||||
import mindspore.common.dtype as mstype
|
||||
|
||||
from src.cnn import CNN
|
||||
from src.gru import GRU
|
||||
from src.lstm import LSTM
|
||||
from src.weight_init import lstm_default_state
|
||||
|
||||
|
||||
class BidirectionalLSTM(nn.Cell):
|
||||
"""Bidirectional LSTM with a Dense layer
|
||||
|
||||
Args:
|
||||
batch_size(int): batch size of input data
|
||||
input_size(int): Size of time sequence
|
||||
hidden_size(int): the hidden size of LSTM layers
|
||||
output_size(int): the output size of the dense layer
|
||||
"""
|
||||
def __init__(self, batch_size, input_size, hidden_size, output_size):
|
||||
super(BidirectionalLSTM, self).__init__()
|
||||
self.rnn = LSTM(input_size=input_size, hidden_size=hidden_size, bidirectional=True).to_float(mstype.float16)
|
||||
self.h, self.c = lstm_default_state(batch_size, hidden_size, bidirectional=True)
|
||||
self.embedding = nn.Dense(hidden_size * 2, output_size).to_float(mstype.float16)
|
||||
self.shape = P.Shape()
|
||||
self.reshape = P.Reshape()
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self, inputs):
|
||||
inputs = self.cast(inputs, mstype.float16)
|
||||
recurrent, _ = self.rnn(inputs, (self.h, self.c))
|
||||
T, b, h = self.shape(recurrent)
|
||||
t_rec = self.reshape(recurrent, (T * b, h))
|
||||
output = self.embedding(t_rec)
|
||||
output = self.reshape(output, (T, b, -1))
|
||||
return output
|
||||
|
||||
|
||||
class AttnDecoderRNN(nn.Cell):
|
||||
"""Attention Decoder Structure with a one-layer GRU
|
||||
|
||||
Args:
|
||||
hidden_size(int): the hidden size
|
||||
output_size(int): the output size
|
||||
max_length(iht): max time step of the decoder
|
||||
dropout_p(float): dropout probability, default is 0.1
|
||||
"""
|
||||
def __init__(self, hidden_size, output_size, max_length, dropout_p=0.1):
|
||||
super(AttnDecoderRNN, self).__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.output_size = output_size
|
||||
self.dropout_p = dropout_p
|
||||
self.max_length = max_length
|
||||
self.embedding = nn.Embedding(self.output_size, self.hidden_size)
|
||||
self.attn = nn.Dense(in_channels=self.hidden_size * 2, out_channels=self.max_length).to_float(mstype.float16)
|
||||
self.attn_combine = nn.Dense(in_channels=self.hidden_size * 2,
|
||||
out_channels=self.hidden_size).to_float(mstype.float16)
|
||||
self.dropout = nn.Dropout(keep_prob=1.0 - self.dropout_p)
|
||||
self.gru = GRU(hidden_size, hidden_size).to_float(mstype.float16)
|
||||
self.out = nn.Dense(in_channels=self.hidden_size, out_channels=self.output_size).to_float(mstype.float16)
|
||||
self.transpose = P.Transpose()
|
||||
self.concat = P.Concat(axis=2)
|
||||
self.concat1 = P.Concat(axis=1)
|
||||
self.softmax = P.Softmax(axis=1)
|
||||
self.relu = P.ReLU()
|
||||
self.log_softmax = P.LogSoftmax(axis=1)
|
||||
self.bmm = P.BatchMatMul()
|
||||
self.unsqueeze = P.ExpandDims()
|
||||
self.squeeze = P.Squeeze(1)
|
||||
self.squeeze1 = P.Squeeze(0)
|
||||
self.cast = P.Cast()
|
||||
|
||||
def construct(self, inputs, hidden, encoder_outputs):
|
||||
embedded = self.embedding(inputs)
|
||||
embedded = self.transpose(embedded, (1, 0, 2))
|
||||
embedded = self.dropout(embedded)
|
||||
embedded = self.cast(embedded, mstype.float16)
|
||||
|
||||
embedded_concat = self.concat((embedded, hidden))
|
||||
embedded_concat = self.squeeze1(embedded_concat)
|
||||
attn_weights = self.softmax(self.attn(embedded_concat))
|
||||
attn_weights = self.unsqueeze(attn_weights, 1)
|
||||
perm_encoder_outputs = self.transpose(encoder_outputs, (1, 0, 2))
|
||||
attn_applied = self.bmm(attn_weights, perm_encoder_outputs)
|
||||
attn_applied = self.squeeze(attn_applied)
|
||||
embedded_squeeze = self.squeeze1(embedded)
|
||||
|
||||
output = self.concat1((embedded_squeeze, attn_applied))
|
||||
output = self.attn_combine(output)
|
||||
output = self.unsqueeze(output, 0)
|
||||
output = self.relu(output)
|
||||
|
||||
gru_hidden = self.squeeze1(hidden)
|
||||
output, hidden, _, _, _, _ = self.gru(output, gru_hidden)
|
||||
output = self.squeeze1(output)
|
||||
output = self.log_softmax(self.out(output))
|
||||
|
||||
return output, hidden, attn_weights
|
||||
|
||||
|
||||
class Encoder(nn.Cell):
|
||||
"""Encoder with a CNN and two BidirectionalLSTM layers
|
||||
|
||||
Args:
|
||||
batch_size(int): batch size of input data
|
||||
conv_out_dim(int): the output dimension of the cnn layer
|
||||
hidden_size(int): the hidden size of LSTM layers
|
||||
"""
|
||||
def __init__(self, batch_size, conv_out_dim, hidden_size):
|
||||
super(Encoder, self).__init__()
|
||||
self.cnn = CNN(int(conv_out_dim/4))
|
||||
self.lstm1 = BidirectionalLSTM(batch_size, conv_out_dim, hidden_size, hidden_size).to_float(mstype.float16)
|
||||
self.lstm2 = BidirectionalLSTM(batch_size, hidden_size, hidden_size, hidden_size).to_float(mstype.float16)
|
||||
self.transpose = P.Transpose()
|
||||
self.cast = P.Cast()
|
||||
self.split = P.Split(axis=3, output_num=4)
|
||||
self.concat = P.Concat(axis=1)
|
||||
|
||||
def construct(self, inputs):
|
||||
inputs = self.cast(inputs, mstype.float32)
|
||||
(x1, x2, x3, x4) = self.split(inputs)
|
||||
conv1 = self.cnn(x1)
|
||||
conv2 = self.cnn(x2)
|
||||
conv3 = self.cnn(x3)
|
||||
conv4 = self.cnn(x4)
|
||||
conv = self.concat((conv1, conv2, conv3, conv4))
|
||||
conv = self.transpose(conv, (2, 0, 1))
|
||||
output = self.lstm1(conv)
|
||||
output = self.lstm2(output)
|
||||
return output
|
||||
|
||||
|
||||
class Decoder(nn.Cell):
|
||||
"""Decoder
|
||||
|
||||
Args:
|
||||
hidden_size(int): the hidden size
|
||||
output_size(int): the output size
|
||||
max_length(iht): max time step of the decoder
|
||||
dropout_p(float): dropout probability, default is 0.1
|
||||
"""
|
||||
def __init__(self, hidden_size, output_size, max_length, dropout_p=0.1):
|
||||
super(Decoder, self).__init__()
|
||||
self.decoder = AttnDecoderRNN(hidden_size, output_size, max_length, dropout_p)
|
||||
|
||||
def construct(self, inputs, hidden, encoder_outputs):
|
||||
return self.decoder(inputs, hidden, encoder_outputs)
|
|
@ -0,0 +1,51 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Util class or function."""
|
||||
|
||||
from __future__ import absolute_import, division, print_function, unicode_literals
|
||||
|
||||
import os
|
||||
import codecs
|
||||
import logging
|
||||
|
||||
|
||||
def initialize_vocabulary(vocabulary_path):
|
||||
"""
|
||||
initialize vocabulary from file.
|
||||
assume the vocabulary is stored one-item-per-line
|
||||
"""
|
||||
characters_class = 9999
|
||||
|
||||
if os.path.exists(vocabulary_path):
|
||||
rev_vocab = []
|
||||
with codecs.open(vocabulary_path, 'r', encoding='utf-8') as voc_file:
|
||||
rev_vocab = [line.strip() for line in voc_file]
|
||||
|
||||
vocab = {x: y for (y, x) in enumerate(rev_vocab)}
|
||||
|
||||
reserved_char_size = characters_class - len(rev_vocab)
|
||||
if reserved_char_size < 0:
|
||||
raise ValueError("Number of characters in vocabulary is equal or larger than config.characters_class")
|
||||
|
||||
for _ in range(reserved_char_size):
|
||||
rev_vocab.append('')
|
||||
|
||||
# put space at the last position
|
||||
vocab[' '] = len(rev_vocab)
|
||||
rev_vocab.append(' ')
|
||||
logging.info("Initializing vocabulary ends: %s", vocabulary_path)
|
||||
return vocab, rev_vocab
|
||||
|
||||
raise ValueError("Initializing vocabulary ends: %s" % vocabulary_path)
|
|
@ -0,0 +1,41 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#" ============================================================================
|
||||
"""
|
||||
weights initialization
|
||||
"""
|
||||
import math
|
||||
import numpy as np
|
||||
from mindspore import Tensor, Parameter
|
||||
|
||||
|
||||
def lstm_default_state(batch_size, hidden_size, bidirectional, num_layers=1):
|
||||
"""init default input."""
|
||||
num_directions = 2 if bidirectional else 1
|
||||
h = Tensor(np.zeros((num_layers * num_directions, batch_size, hidden_size)).astype(np.float32))
|
||||
c = Tensor(np.zeros((num_layers * num_directions, batch_size, hidden_size)).astype(np.float32))
|
||||
return h, c
|
||||
|
||||
|
||||
def gru_default_state(input_size, hidden_size):
|
||||
stdv = 1 / math.sqrt(hidden_size)
|
||||
weight_i = Parameter(Tensor(np.random.uniform(-stdv, stdv, (input_size, 3*hidden_size)).astype(np.float32)),
|
||||
name='weight_i')
|
||||
weight_h = Parameter(Tensor(np.random.uniform(-stdv, stdv, (input_size, 3*hidden_size)).astype(np.float32)),
|
||||
name='weight_h')
|
||||
bias_i = Parameter(Tensor(np.random.uniform(-stdv, stdv, (3*hidden_size)).astype(np.float32)),
|
||||
name='bias_i')
|
||||
bias_h = Parameter(Tensor(np.random.uniform(-stdv, stdv, (3*hidden_size)).astype(np.float32)),
|
||||
name='bias_h')
|
||||
return weight_i, weight_h, bias_i, bias_h
|
|
@ -0,0 +1,158 @@
|
|||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""
|
||||
CRNN-Seq2Seq-OCR train.
|
||||
|
||||
"""
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import datetime
|
||||
|
||||
import mindspore.nn as nn
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.common import set_seed
|
||||
from mindspore import Tensor
|
||||
from mindspore import context
|
||||
from mindspore.communication.management import init
|
||||
from mindspore.train.callback import ModelCheckpoint
|
||||
from mindspore.train.callback import CheckpointConfig, LossMonitor, TimeMonitor
|
||||
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
from src.config import config
|
||||
from src.dataset import create_ocr_train_dataset
|
||||
from src.logger import get_logger
|
||||
from src.attention_ocr import AttentionOCR, AttentionOCRWithLossCell, TrainingWrapper
|
||||
|
||||
|
||||
set_seed(1)
|
||||
|
||||
|
||||
def parse_args():
|
||||
"""Parse train arguments."""
|
||||
parser = argparse.ArgumentParser('mindspore CRNN-Seq2Seq-OCR training')
|
||||
|
||||
# device related
|
||||
parser.add_argument("--device_target", type=str, default="Ascend",
|
||||
help="device where the code will be implemented.")
|
||||
parser.add_argument("--device_id", type=int, default=0, help="Device id, default: 0.")
|
||||
|
||||
# distributed related
|
||||
parser.add_argument('--is_distributed', type=int, default=0,
|
||||
help='Distribute train or not, 1 for yes, 0 for no. Default: 0')
|
||||
parser.add_argument('--rank_id', type=int, default=0, help='Local rank of distributed. Default: 0')
|
||||
parser.add_argument('--device_num', type=int, default=1, help='World size of device. Default: 1')
|
||||
|
||||
#dataset related
|
||||
parser.add_argument('--mindrecord_file', type=str, default='', help='Train dataset directory.')
|
||||
|
||||
# logging related
|
||||
parser.add_argument('--log_interval', type=int, default=100, help='Logging interval steps. Default: 100')
|
||||
parser.add_argument('--ckpt_path', type=str, default='outputs/', help='Checkpoint save location. Default: outputs/')
|
||||
parser.add_argument('--pre_checkpoint_path', type=str, default='', help='Checkpoint save location.')
|
||||
parser.add_argument('--ckpt_interval', type=int, default=None, help='Save checkpoint interval. Default: None')
|
||||
|
||||
parser.add_argument('--is_save_on_master', type=int, default=0,
|
||||
help='Save ckpt on master or all rank, 1 for master, 0 for all ranks. Default: 0')
|
||||
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
# logger
|
||||
args.outputs_dir = os.path.join(args.ckpt_path,
|
||||
datetime.datetime.now().strftime('%Y-%m-%d_time_%H_%M_%S'))
|
||||
|
||||
return args
|
||||
|
||||
|
||||
def train():
|
||||
"""Train function."""
|
||||
args = parse_args()
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target, device_id=args.device_id)
|
||||
|
||||
if args.is_distributed:
|
||||
rank = args.rank_id
|
||||
device_num = args.device_num
|
||||
context.set_auto_parallel_context(device_num=device_num,
|
||||
parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
gradients_mean=True)
|
||||
init()
|
||||
else:
|
||||
rank = 0
|
||||
device_num = 1
|
||||
|
||||
# Logger
|
||||
args.logger = get_logger(args.outputs_dir, rank)
|
||||
args.rank_save_ckpt_flag = 0
|
||||
if args.is_save_on_master:
|
||||
if rank == 0:
|
||||
args.rank_save_ckpt_flag = 1
|
||||
else:
|
||||
args.rank_save_ckpt_flag = 1
|
||||
|
||||
# DATASET
|
||||
dataset = create_ocr_train_dataset(args.mindrecord_file,
|
||||
config.batch_size,
|
||||
rank_size=device_num,
|
||||
rank_id=rank)
|
||||
args.steps_per_epoch = dataset.get_dataset_size()
|
||||
args.logger.info('Finish loading dataset')
|
||||
|
||||
if not args.ckpt_interval:
|
||||
args.ckpt_interval = args.steps_per_epoch
|
||||
args.logger.save_args(args)
|
||||
|
||||
network = AttentionOCR(config.batch_size,
|
||||
int(config.img_width / 4),
|
||||
config.encoder_hidden_size,
|
||||
config.decoder_hidden_size,
|
||||
config.decoder_output_size,
|
||||
config.max_length,
|
||||
config.dropout_p)
|
||||
|
||||
if args.pre_checkpoint_path:
|
||||
param_dict = load_checkpoint(args.pre_checkpoint_path)
|
||||
load_param_into_net(network, param_dict)
|
||||
|
||||
network = AttentionOCRWithLossCell(network, config.max_length)
|
||||
|
||||
lr = Tensor(config.lr, mstype.float32)
|
||||
opt = nn.Adam(network.trainable_params(), lr, beta1=config.adam_beta1, beta2=config.adam_beta2,
|
||||
loss_scale=config.loss_scale)
|
||||
|
||||
network = TrainingWrapper(network, opt, sens=config.loss_scale)
|
||||
|
||||
args.logger.info('Finished get network')
|
||||
|
||||
callback = [TimeMonitor(data_size=1), LossMonitor()]
|
||||
if args.rank_save_ckpt_flag:
|
||||
ckpt_config = CheckpointConfig(save_checkpoint_steps=args.steps_per_epoch,
|
||||
keep_checkpoint_max=config.keep_checkpoint_max)
|
||||
save_ckpt_path = os.path.join(args.outputs_dir, 'ckpt_' + str(rank) + '/')
|
||||
ckpt_cb = ModelCheckpoint(config=ckpt_config,
|
||||
directory=save_ckpt_path,
|
||||
prefix="crnn_seq2seq_ocr")
|
||||
callback.append(ckpt_cb)
|
||||
|
||||
model = Model(network)
|
||||
model.train(config.num_epochs, dataset, callbacks=callback, dataset_sink_mode=False)
|
||||
|
||||
args.logger.info('==========Training Done===============')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
train()
|
Loading…
Reference in New Issue