!6383 Add modelzoo CNNCTC Network.

Merge pull request !6383 from linqingke/fasterrcnn
This commit is contained in:
mindspore-ci-bot 2020-09-29 10:23:17 +08:00 committed by Gitee
commit 737e27d721
15 changed files with 1730 additions and 1 deletions


@ -0,0 +1,354 @@
# Contents
- [CNNCTC Description](#CNNCTC-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Features](#features)
- [Mixed Precision](#mixed-precision)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
- [Script and Sample Code](#script-and-sample-code)
- [Script Parameters](#script-parameters)
- [Training Process](#training-process)
- [Training](#training)
- [Distributed Training](#distributed-training)
- [Evaluation Process](#evaluation-process)
- [Evaluation](#evaluation)
- [Model Description](#model-description)
- [Performance](#performance)
- [Training Performance](#training-performance)
- [Evaluation Performance](#evaluation-performance)
- [How to use](#how-to-use)
- [Inference](#inference)
- [Continue Training on the Pretrained Model](#continue-training-on-the-pretrained-model)
- [Transfer Learning](#transfer-learning)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)
# [CNNCTC Description](#contents)
This paper makes three major contributions to scene text recognition (STR).
First, we examine the inconsistencies between training and evaluation datasets and the performance gaps that result from them.
Second, we introduce a unified four-stage STR framework into which most existing STR models fit.
Using this framework allows for the extensive evaluation of previously proposed STR modules and the discovery of previously
unexplored module combinations. Third, we analyze the module-wise contributions to performance in terms of accuracy, speed,
and memory demand, under one consistent set of training and evaluation datasets. These analyses remove the obstacles that have
prevented fair comparisons of the performance gains of existing modules.
[Paper](https://arxiv.org/abs/1904.01906): J. Baek, G. Kim, J. Lee, S. Park, D. Han, S. Yun, S. J. Oh, and H. Lee, “What is wrong with scene text recognition model comparisons? dataset and model analysis,” ArXiv, vol. abs/1904.01906, 2019.
# [Model Architecture](#contents)
CNN+CTC combines a ResNet-based convolutional feature extractor with a fully connected prediction layer trained with CTC (Connectionist Temporal Classification) loss. This is an example of training the CNN+CTC model for text recognition on the MJSynth and SynthText datasets with MindSpore.
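The sketch below shows how the pieces defined in `src/cnn_ctc.py` fit together; it assumes the defaults from `src/config.py` and is illustrative rather than a complete training script:
```
from src.config import Config_CNNCTC
from src.cnn_ctc import CNNCTC_Model, ctc_loss, WithLossCell

cfg = Config_CNNCTC()
# ResNet feature extractor + Dense prediction head; the output has shape
# (batch, FINAL_FEATURE_WIDTH, NUM_CLASS), i.e. one class distribution per time step.
net = CNNCTC_Model(cfg.NUM_CLASS, cfg.HIDDEN_SIZE, cfg.FINAL_FEATURE_WIDTH)
# The CTC loss consumes the per-time-step predictions together with the sparse labels.
net_with_loss = WithLossCell(net, ctc_loss())
```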
# [Dataset](#contents)
The [MJSynth](https://www.robots.ox.ac.uk/~vgg/data/text/) and [SynthText](https://github.com/ankush-me/SynthText) datasets are used for model training. The [IIIT 5K-word dataset](https://cvit.iiit.ac.in/research/projects/cvit-projects/the-iiit-5k-word-dataset) is used for evaluation.
- step 1:
All the datasets have been preprocessed and stored in .lmdb format and can be downloaded [**HERE**](https://drive.google.com/drive/folders/192UfE9agQUMNq6AgU3_E05_FcPZK4hyt).
- step 2:
Uncompress the downloaded file, and rename the MJSynth dataset to MJ, the SynthText dataset to ST, and the IIIT dataset to IIIT.
- step 3:
Move the three datasets mentioned above into a `cnnctc_data` folder; the structure should be as follows:
```
|--- CNNCTC/
|--- cnnctc_data/
|--- ST/
data.mdb
lock.mdb
|--- MJ/
data.mdb
lock.mdb
|--- IIIT/
data.mdb
lock.mdb
......
```
- step 4:
Preprocess the dataset by running:
```
python src/preprocess_dataset.py
```
This takes around 75 minutes.
# [Features](#contents)
## Mixed Precision
The [mixed precision](https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/mixed_precision.html) training method accelerates the deep learning neural network training process by using both the single-precision and half-precision data formats, and maintains the network precision achieved by the single-precision training at the same time. Mixed precision training can accelerate the computation process, reduce memory usage, and enable a larger model or batch size to be trained on specific hardware.
For FP16 operators, if the input data type is FP32, the MindSpore backend will automatically handle it with reduced precision. Users can check the reduced-precision operators by enabling the INFO log and then searching for `reduce precision`.
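The training script in this PR enables mixed precision through the `Model` wrapper; a minimal sketch, assuming `net_with_loss`, `opt`, and `cfg` have been created as in `train.py`:
```
from mindspore.train.model import Model
from mindspore.train.loss_scale_manager import FixedLossScaleManager

# Fixed loss scaling guards against FP16 gradient underflow; the scale value comes from config.py.
loss_scale_manager = FixedLossScaleManager(cfg.LOSS_SCALE, drop_overflow_update=False)
# amp_level="O2" casts most of the network to float16 while keeping BatchNorm and the loss in float32.
model = Model(net_with_loss, optimizer=opt, loss_scale_manager=loss_scale_manager, amp_level="O2")
```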
# [Environment Requirements](#contents)
- Hardware (Ascend)
- Prepare a hardware environment with an Ascend processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get access to the resources.
- Framework
- [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below
- [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html)
- [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html)
# [Quick Start](#contents)
- Install dependencies:
```
pip install lmdb
pip install Pillow
pip install tqdm
pip install six
```
- Standalone Training:
```
bash scripts/run_standalone_train_ascend.sh $PRETRAINED_CKPT
```
- Distributed Training:
```
bash scripts/run_distribute_train_ascend.sh $RANK_TABLE_FILE $PRETRAINED_CKPT
```
- Evaluation:
```
bash scripts/run_eval_ascend.sh $TRAINED_CKPT
```
# [Script Description](#contents)
## [Script and Sample Code](#contents)
The overall code structure is as follows:
```
|--- CNNCTC/
|---README.md // descriptions about cnnctc
|---train.py // train scripts
|---eval.py // eval scripts
|---scripts
|---run_standalone_train_ascend.sh // shell script for standalone on ascend
|---run_distribute_train_ascend.sh // shell script for distributed on ascend
|---run_eval_ascend.sh // shell script for eval on ascend
|---src
|---__init__.py // init file
|---cnn_ctc.py // cnn_ctc network
|---config.py // total config
|---callback.py // loss callback file
|---dataset.py // process dataset
|---util.py // routine operation
|---generate_hccn_file.py // generate distribute json file
|---preprocess_dataset.py // preprocess dataset
```
## [Script Parameters](#contents)
Parameters for both training and evaluation can be set in `config.py` (see the sketch after this list).
Arguments:
* `--CHARACTER`: Character labels.
* `--NUM_CLASS`: The number of classes, including all character labels and the `<blank>` label for CTCLoss.
* `--HIDDEN_SIZE`: Model hidden size.
* `--FINAL_FEATURE_WIDTH`: The number of output features (time steps) fed to CTC.
* `--IMG_H`: The height of the input image.
* `--IMG_W`: The width of the input image.
* `--TRAIN_DATASET_PATH`: The path to the training dataset.
* `--TRAIN_DATASET_INDEX_PATH`: The path to the training dataset index file, which determines the sample order.
* `--TRAIN_BATCH_SIZE`: Training batch size. The batch size and index file must together ensure that the input data has a fixed shape.
* `--TRAIN_DATASET_SIZE`: Training dataset size.
* `--TEST_DATASET_PATH`: The path to the test dataset.
* `--TEST_BATCH_SIZE`: Test batch size.
* `--TEST_DATASET_SIZE`: Test dataset size.
* `--TRAIN_EPOCHS`: Total training epochs.
* `--CKPT_PATH`: The path to a model checkpoint file; can be used to resume training and for evaluation.
* `--SAVE_PATH`: The path for saving model checkpoint files.
* `--LR`: Learning rate for standalone training.
* `--LR_PARA`: Learning rate for distributed training.
* `--MOMENTUM`: Momentum.
* `--LOSS_SCALE`: Loss scale to prevent gradient underflow.
* `--SAVE_CKPT_PER_N_STEP`: Save a model checkpoint file every N steps.
* `--KEEP_CKPT_MAX_NUM`: The maximum number of saved model checkpoint files.
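These values are attributes of the `Config_CNNCTC` class, so a script can override them programmatically. The pattern below mirrors how `train.py` and `eval.py` in this PR override `CKPT_PATH`; the other values shown are only illustrative:
```
from src.config import Config_CNNCTC

cfg = Config_CNNCTC()
# Override individual attributes before building the dataset and model.
cfg.CKPT_PATH = '/path/to/pretrained.ckpt'   # hypothetical path
cfg.TRAIN_BATCH_SIZE = 192
```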
## [Training Process](#contents)
### Training
- Standalone Training:
```
bash scripts/run_standalone_train_ascend.sh $PRETRAINED_CKPT
```
Results and checkpoints are written to the `./train` folder. The log can be found in `./train/log` and loss values are recorded in `./train/loss.log`.
`$PRETRAINED_CKPT` is the path to a model checkpoint and is **optional**. If none is given, the model will be trained from scratch.
- Distributed Training:
```
bash scripts/run_distribute_train_ascend.sh $RANK_TABLE_FILE $PRETRAINED_CKPT
```
Results and checkpoints are written to the `./train_parallel_{i}` folder for device `i`.
The log can be found in `./train_parallel_{i}/log_{i}.log` and loss values are recorded in `./train_parallel_{i}/loss.log`.
`$RANK_TABLE_FILE` is needed when you run a distributed task on Ascend (see the note below on generating it).
`$PRETRAINED_CKPT` is the path to a model checkpoint and is **optional**. If none is given, the model will be trained from scratch.
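The distributed launch script calls `src/generate_hccn_file.py` to create the rank table for a single 8-device server if it does not exist yet; it can also be generated manually, for example (the JSON path below is just the script's default):
```
python src/generate_hccn_file.py --rank_file=scripts/rank_table_8p.json
```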
### Training Result
Training results are stored in the example path, in folders whose names begin with "train" or "train_parallel". You can find the checkpoint files there, together with results like the following in `loss.log`.
```
# distribute training result(8p)
epoch: 1 step: 1 , loss is 76.25, average time per step is 0.335177839748392712
epoch: 1 step: 2 , loss is 73.46875, average time per step is 0.36798572540283203
epoch: 1 step: 3 , loss is 69.46875, average time per step is 0.3429678678512573
epoch: 1 step: 4 , loss is 64.3125, average time per step is 0.33512671788533527
epoch: 1 step: 5 , loss is 58.375, average time per step is 0.33149147033691406
epoch: 1 step: 6 , loss is 52.7265625, average time per step is 0.3292975425720215
...
epoch: 1 step: 8689 , loss is 9.706798802612482, average time per step is 0.3184656601312549
epoch: 1 step: 8690 , loss is 9.70612545289855, average time per step is 0.3184725407765116
epoch: 1 step: 8691 , loss is 9.70695776049204, average time per step is 0.31847309686135555
epoch: 1 step: 8692 , loss is 9.707279624277456, average time per step is 0.31847339290613375
epoch: 1 step: 8693 , loss is 9.70763437950938, average time per step is 0.3184720295013031
epoch: 1 step: 8694 , loss is 9.707695425072046, average time per step is 0.31847410284595573
epoch: 1 step: 8695 , loss is 9.708408273381295, average time per step is 0.31847338271072345
epoch: 1 step: 8696 , loss is 9.708703753591953, average time per step is 0.3184726025560777
epoch: 1 step: 8697 , loss is 9.709536406025824, average time per step is 0.31847212061114694
epoch: 1 step: 8698 , loss is 9.708542263610315, average time per step is 0.3184715309307257
```
## [Evaluation Process](#contents)
### Evaluation
- Evaluation:
```
bash scripts/run_eval_ascend.sh $TRAINED_CKPT
```
The model will be evaluated on the IIIT dataset; sample predictions and the overall accuracy will be printed.
# [Model Description](#contents)
## [Performance](#contents)
### Training Performance
| Parameters                 | CNNCTC                                                       |
| -------------------------- | ------------------------------------------------------------ |
| Model Version              | V1                                                           |
| Resource                   | Ascend 910; CPU 2.60 GHz, 192 cores; memory 755 GB           |
| Uploaded Date              | 09/28/2020 (month/day/year)                                  |
| MindSpore Version          | 1.0.0                                                        |
| Dataset                    | MJSynth, SynthText                                           |
| Training Parameters | epoch=3, batch_size=192 |
| Optimizer | RMSProp |
| Loss Function | CTCLoss |
| Speed | 1pc: 300 ms/step; 8pcs: 310 ms/step |
| Total time | 1pc: 18 hours; 8pcs: 2.3 hours |
| Parameters (M) | 177 |
| Scripts | https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/cnnctc |
### Evaluation Performance
| Parameters          | CNNCTC                      |
| ------------------- | --------------------------- |
| Model Version | V1 |
| Resource | Ascend 910 |
| Uploaded Date | 09/28/2020 (month/day/year) |
| MindSpore Version | 1.0.0 |
| Dataset | IIIT5K |
| batch_size | 192 |
| outputs | Accuracy |
| Accuracy | 85% |
| Model for inference | 675M (.ckpt file) |
## [How to use](#contents)
### Inference
If you need to use the trained model to perform inference on multiple hardware platforms, such as GPU, Ascend 910, or Ascend 310, you can refer to this [Link](https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/network_migration.html). The following is a simple example of the steps:
- Running on Ascend
```
# Set context
context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
context.set_context(device_id=cfg.device_id)
# Load unseen dataset for inference
dataset = dataset.create_dataset(cfg.data_path, 1, False)
# Define model
net = CNNCTC(cfg.NUM_CLASS, cfg.HIDDEN_SIZE, cfg.FINAL_FEATURE_WIDTH)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01,
cfg.momentum, weight_decay=cfg.weight_decay)
loss = P.CTCLoss(preprocess_collapse_repeated=False,
ctc_merge_repeated=True,
ignore_longer_outputs_than_inputs=False)
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
# Load pre-trained model
param_dict = load_checkpoint(cfg.checkpoint_path)
load_param_into_net(net, param_dict)
net.set_train(False)
# Make predictions on the unseen dataset
acc = model.eval(dataset)
print("accuracy: ", acc)
```
### Continue Training on the Pretrained Model
- Running on Ascend
```
# Load dataset
dataset = create_dataset(cfg.data_path, 1)
batch_num = dataset.get_dataset_size()
# Define model
net = CNNCTC(cfg.NUM_CLASS, cfg.HIDDEN_SIZE, cfg.FINAL_FEATURE_WIDTH)
# Continue training if pre_trained is set to True
if cfg.pre_trained:
param_dict = load_checkpoint(cfg.checkpoint_path)
load_param_into_net(net, param_dict)
lr = lr_steps(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size,
steps_per_epoch=batch_num)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()),
Tensor(lr), cfg.momentum, weight_decay=cfg.weight_decay)
loss = P.CTCLoss(preprocess_collapse_repeated=False,
ctc_merge_repeated=True,
ignore_longer_outputs_than_inputs=False)
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},
amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None)
# Set callbacks
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5,
keep_checkpoint_max=cfg.keep_checkpoint_max)
time_cb = TimeMonitor(data_size=batch_num)
ckpoint_cb = ModelCheckpoint(prefix="train_googlenet_cifar10", directory="./",
config=config_ck)
loss_cb = LossMonitor()
# Start training
model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb])
print("train success")
```
# [ModelZoo Homepage](#contents)
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).


@ -0,0 +1,109 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""cnnctc eval"""
import argparse
import time
import numpy as np
from mindspore import Tensor, context
import mindspore.common.dtype as mstype
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.dataset import GeneratorDataset
from src.util import CTCLabelConverter, AverageMeter
from src.config import Config_CNNCTC
from src.dataset import IIIT_Generator_batch
from src.cnn_ctc import CNNCTC_Model
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False,
save_graphs_path=".", enable_auto_mixed_precision=False)
def test_dataset_creator():
ds = GeneratorDataset(IIIT_Generator_batch, ['img', 'label_indices', 'text', 'sequence_length', 'label_str'])
return ds
def test(config):
ds = test_dataset_creator()
net = CNNCTC_Model(config.NUM_CLASS, config.HIDDEN_SIZE, config.FINAL_FEATURE_WIDTH)
ckpt_path = config.CKPT_PATH
param_dict = load_checkpoint(ckpt_path)
load_param_into_net(net, param_dict)
print('parameters loaded! from: ', ckpt_path)
converter = CTCLabelConverter(config.CHARACTER)
model_run_time = AverageMeter()
npu_to_cpu_time = AverageMeter()
postprocess_time = AverageMeter()
count = 0
correct_count = 0
for data in ds.create_tuple_iterator():
img, _, text, _, length = data
img_tensor = Tensor(img, mstype.float32)
model_run_begin = time.time()
model_predict = net(img_tensor)
model_run_end = time.time()
model_run_time.update(model_run_end - model_run_begin)
npu_to_cpu_begin = time.time()
model_predict = np.squeeze(model_predict.asnumpy())
npu_to_cpu_end = time.time()
npu_to_cpu_time.update(npu_to_cpu_end - npu_to_cpu_begin)
postprocess_begin = time.time()
preds_size = np.array([model_predict.shape[1]] * config.TEST_BATCH_SIZE)
preds_index = np.argmax(model_predict, 2)
preds_index = np.reshape(preds_index, [-1])
preds_str = converter.decode(preds_index, preds_size)
postprocess_end = time.time()
postprocess_time.update(postprocess_end - postprocess_begin)
label_str = converter.reverse_encode(text.asnumpy(), length.asnumpy())
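# Skip the first batch when averaging timings: it includes graph compilation and warm-up overhead.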
if count == 0:
model_run_time.reset()
npu_to_cpu_time.reset()
postprocess_time.reset()
else:
print('---------model run time--------', model_run_time.avg)
print('---------npu_to_cpu run time--------', npu_to_cpu_time.avg)
print('---------postprocess run time--------', postprocess_time.avg)
print("Prediction samples: \n", preds_str[:5])
print("Ground truth: \n", label_str[:5])
for pred, label in zip(preds_str, label_str):
if pred == label:
correct_count += 1
count += 1
print('accuracy: ', correct_count / count)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description="FasterRcnn training")
parser.add_argument('--device_id', type=int, default=0, help="Device id, default is 0.")
parser.add_argument("--ckpt_path", type=str, default="", help="trained file path.")
args_opt = parser.parse_args()
cfg = Config_CNNCTC()
if args_opt.ckpt_path != "":
cfg.CKPT_PATH = args_opt.ckpt_path
test(cfg)


@ -0,0 +1,57 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
current_exec_path=$(pwd)
echo ${current_exec_path}
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
echo $PATH1
PATH2=$(get_real_path $2)
echo $PATH2
python ${current_exec_path}/src/generate_hccn_file.py --rank_file=$PATH1
export RANK_TABLE_FILE=$PATH1
export RANK_SIZE=8
ulimit -u unlimited
for((i=0;i<$RANK_SIZE;i++));
do
rm ./train_parallel_$i/ -rf
mkdir ./train_parallel_$i
cp ./*.py ./train_parallel_$i
cp ./scripts/*.sh ./train_parallel_$i
cp -r ./src ./train_parallel_$i
cd ./train_parallel_$i || exit
export RANK_ID=$i
export DEVICE_ID=$i
echo "start training for rank $RANK_ID, device $DEVICE_ID"
if [ -f $PATH2 ]
then
python train.py --device_id=$i --ckpt_path=$PATH2 --run_distribute=True >log_$i.log 2>&1 &
else
python train.py --device_id=$i --run_distribute=True >log_$i.log 2>&1 &
fi
cd .. || exit
done


@ -0,0 +1,54 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# -ne 1 ]
then
echo "Usage: sh run_eval_ascend.sh [TRAINED_CKPT]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
echo $PATH1
if [ ! -f $PATH1 ]
then
echo "error: TRAINED_CKPT=$PATH1 is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_ID=0
if [ -d "eval" ];
then
rm -rf ./eval
fi
mkdir ./eval
cp ./*.py ./eval
cp ./scripts/*.sh ./eval
cp -r ./src ./eval
cd ./eval || exit
echo "start infering for device $DEVICE_ID"
env > env.log
python eval.py --device_id=$DEVICE_ID --ckpt_path=$PATH1 &> log &
cd .. || exit


@ -0,0 +1,45 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
ulimit -u unlimited
if [ -d "train" ];
then
rm -rf ./train
fi
mkdir ./train
cp ./*.py ./train
cp ./scripts/*.sh ./train
cp -r ./src ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log
if [ -f $PATH1 ]
then
python train.py --device_id=$DEVICE_ID --ckpt_path=$PATH1 --run_distribute=False &> log &
else
python train.py --device_id=$DEVICE_ID --run_distribute=False &> log &
fi
cd .. || exit


@ -0,0 +1,15 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""src init file"""


@ -0,0 +1,71 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""loss callback"""
import time
from mindspore.train.callback import Callback
from .util import AverageMeter
class LossCallBack(Callback):
"""
Monitor the loss and the average per-step time during training.
Note:
If per_print_times is 0, the loss is not printed.
Args:
per_print_times (int): Print the loss every `per_print_times` steps. Default: 1.
"""
def __init__(self, per_print_times=1):
super(LossCallBack, self).__init__()
if not isinstance(per_print_times, int) or per_print_times < 0:
raise ValueError("print_step must be int and >= 0.")
self._per_print_times = per_print_times
self.loss_avg = AverageMeter()
self.timer = AverageMeter()
self.start_time = time.time()
def step_end(self, run_context):
cb_params = run_context.original_args()
loss = cb_params.net_outputs.asnumpy()
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
cur_num = cb_params.cur_step_num
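# Restart the running averages every 2000 steps so the logged loss and per-step time reflect the most recent window.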
if cur_step_in_epoch % 2000 == 1:
self.loss_avg = AverageMeter()
self.timer = AverageMeter()
self.start_time = time.time()
else:
self.timer.update(time.time() - self.start_time)
self.start_time = time.time()
self.loss_avg.update(loss)
if self._per_print_times != 0 and cur_num % self._per_print_times == 0:
loss_file = open("./loss.log", "a+")
loss_file.write("epoch: %s step: %s , loss is %s, average time per step is %s" % (
cb_params.cur_epoch_num, cur_step_in_epoch,
self.loss_avg.avg, self.timer.avg))
loss_file.write("\n")
loss_file.close()
print("epoch: %s step: %s , loss is %s, average time per step is %s" % (
cb_params.cur_epoch_num, cur_step_in_epoch,
self.loss_avg.avg, self.timer.avg))


@ -0,0 +1,255 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""cnn_ctc network define"""
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.initializer import TruncatedNormal, initializer
import mindspore.common.dtype as mstype
class CNNCTC_Model(nn.Cell):
def __init__(self, num_class, hidden_size, final_feature_width):
super(CNNCTC_Model, self).__init__()
self.num_class = num_class
self.hidden_size = hidden_size
self.final_feature_width = final_feature_width
self.FeatureExtraction = ResNet_FeatureExtractor()
self.Prediction = nn.Dense(self.hidden_size, self.num_class)
self.transpose = P.Transpose()
self.reshape = P.Reshape()
def construct(self, x):
x = self.FeatureExtraction(x)
x = self.transpose(x, (0, 3, 1, 2)) # [b, c, h, w] -> [b, w, c, h]
x = self.reshape(x, (-1, self.hidden_size))
x = self.Prediction(x)
x = self.reshape(x, (-1, self.final_feature_width, self.num_class))
return x
class WithLossCell(nn.Cell):
def __init__(self, backbone, loss_fn):
super(WithLossCell, self).__init__(auto_prefix=False)
self._backbone = backbone
self._loss_fn = loss_fn
def construct(self, img, label_indices, text, sequence_length):
model_predict = self._backbone(img)
return self._loss_fn(model_predict, label_indices, text, sequence_length)
@property
def backbone_network(self):
return self._backbone
class ctc_loss(nn.Cell):
def __init__(self):
super(ctc_loss, self).__init__()
self.loss = P.CTCLoss(preprocess_collapse_repeated=False,
ctc_merge_repeated=True,
ignore_longer_outputs_than_inputs=False)
self.mean = P.ReduceMean()
self.transpose = P.Transpose()
self.reshape = P.Reshape()
def construct(self, inputs, labels_indices, labels_values, sequence_length):
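# P.CTCLoss expects time-major inputs of shape (max_time, batch_size, num_classes), so move the time axis first.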
inputs = self.transpose(inputs, (1, 0, 2))
loss, _ = self.loss(inputs, labels_indices, labels_values, sequence_length)
loss = self.mean(loss)
return loss
class ResNet_FeatureExtractor(nn.Cell):
def __init__(self):
super(ResNet_FeatureExtractor, self).__init__()
self.ConvNet = ResNet(3, 512, BasicBlock, [1, 2, 5, 3])
def construct(self, featuremap):
return self.ConvNet(featuremap)
class ResNet(nn.Cell):
def __init__(self, input_channel, output_channel, block, layers):
super(ResNet, self).__init__()
self.output_channel_block = [int(output_channel / 4), int(output_channel / 2), output_channel, output_channel]
self.inplanes = int(output_channel / 8)
self.conv0_1 = ms_conv3x3(input_channel, int(output_channel / 16), stride=1, padding=1, pad_mode='pad')
self.bn0_1 = ms_fused_bn(int(output_channel / 16))
self.conv0_2 = ms_conv3x3(int(output_channel / 16), self.inplanes, stride=1, padding=1, pad_mode='pad')
self.bn0_2 = ms_fused_bn(self.inplanes)
self.relu = P.ReLU()
self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='valid')
self.layer1 = self._make_layer(block, self.output_channel_block[0], layers[0])
self.conv1 = ms_conv3x3(self.output_channel_block[0], self.output_channel_block[0], stride=1, padding=1,
pad_mode='pad')
self.bn1 = ms_fused_bn(self.output_channel_block[0])
self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='valid')
self.layer2 = self._make_layer(block, self.output_channel_block[1], layers[1])
self.conv2 = ms_conv3x3(self.output_channel_block[1], self.output_channel_block[1], stride=1, padding=1,
pad_mode='pad')
self.bn2 = ms_fused_bn(self.output_channel_block[1])
self.pad = P.Pad(((0, 0), (0, 0), (0, 0), (1, 1)))
self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=(2, 1), pad_mode='valid')
self.layer3 = self._make_layer(block, self.output_channel_block[2], layers[2])
self.conv3 = ms_conv3x3(self.output_channel_block[2], self.output_channel_block[2], stride=1, padding=1,
pad_mode='pad')
self.bn3 = ms_fused_bn(self.output_channel_block[2])
self.layer4 = self._make_layer(block, self.output_channel_block[3], layers[3])
self.conv4_1 = ms_conv2x2(self.output_channel_block[3], self.output_channel_block[3], stride=(2, 1),
pad_mode='valid')
self.bn4_1 = ms_fused_bn(self.output_channel_block[3])
self.conv4_2 = ms_conv2x2(self.output_channel_block[3], self.output_channel_block[3], stride=1, padding=0,
pad_mode='valid')
self.bn4_2 = ms_fused_bn(self.output_channel_block[3])
def _make_layer(self, block, planes, blocks, stride=1):
downsample = None
if stride != 1 or self.inplanes != planes * block.expansion:
downsample = nn.SequentialCell(
[ms_conv1x1(self.inplanes, planes * block.expansion, stride=stride),
ms_fused_bn(planes * block.expansion)]
)
layers = []
layers.append(block(self.inplanes, planes, stride, downsample))
self.inplanes = planes * block.expansion
for _ in range(1, blocks):
layers.append(block(self.inplanes, planes))
return nn.SequentialCell(layers)
def construct(self, x):
x = self.conv0_1(x)
x = self.bn0_1(x)
x = self.relu(x)
x = self.conv0_2(x)
x = self.bn0_2(x)
x = self.relu(x)
x = self.maxpool1(x)
x = self.layer1(x)
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.maxpool2(x)
x = self.layer2(x)
x = self.conv2(x)
x = self.bn2(x)
x = self.relu(x)
x = self.pad(x)
x = self.maxpool3(x)
x = self.layer3(x)
x = self.conv3(x)
x = self.bn3(x)
x = self.relu(x)
x = self.layer4(x)
x = self.pad(x)
x = self.conv4_1(x)
x = self.bn4_1(x)
x = self.relu(x)
x = self.conv4_2(x)
x = self.bn4_2(x)
x = self.relu(x)
return x
class BasicBlock(nn.Cell):
expansion = 1
def __init__(self, inplanes, planes, stride=1, downsample=None):
super(BasicBlock, self).__init__()
self.conv1 = ms_conv3x3(inplanes, planes, stride=stride, padding=1, pad_mode='pad')
self.bn1 = ms_fused_bn(planes)
self.conv2 = ms_conv3x3(planes, planes, stride=stride, padding=1, pad_mode='pad')
self.bn2 = ms_fused_bn(planes)
self.relu = P.ReLU()
self.downsample = downsample
self.add = P.TensorAdd()
def construct(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
out = self.conv2(out)
out = self.bn2(out)
if self.downsample is not None:
residual = self.downsample(x)
out = self.add(out, residual)
out = self.relu(out)
return out
def weight_variable(shape, factor=0.1, half_precision=False):
if half_precision:
return initializer(TruncatedNormal(0.02), shape, dtype=mstype.float16)
return TruncatedNormal(0.02)
def ms_conv3x3(in_channels, out_channels, stride=1, padding=0, pad_mode='same', has_bias=False):
"""Get a conv2d layer with 3x3 kernel size."""
init_value = weight_variable((out_channels, in_channels, 3, 3))
return nn.Conv2d(in_channels, out_channels,
kernel_size=3, stride=stride, padding=padding, pad_mode=pad_mode, weight_init=init_value,
has_bias=has_bias)
def ms_conv1x1(in_channels, out_channels, stride=1, padding=0, pad_mode='same', has_bias=False):
"""Get a conv2d layer with 1x1 kernel size."""
init_value = weight_variable((out_channels, in_channels, 1, 1))
return nn.Conv2d(in_channels, out_channels,
kernel_size=1, stride=stride, padding=padding, pad_mode=pad_mode, weight_init=init_value,
has_bias=has_bias)
def ms_conv2x2(in_channels, out_channels, stride=1, padding=0, pad_mode='same', has_bias=False):
"""Get a conv2d layer with 2x2 kernel size."""
init_value = weight_variable((out_channels, in_channels, 2, 2))
return nn.Conv2d(in_channels, out_channels,
kernel_size=2, stride=stride, padding=padding, pad_mode=pad_mode, weight_init=init_value,
has_bias=has_bias)
def ms_fused_bn(channels, momentum=0.1):
"""Get a fused batchnorm"""
return nn.BatchNorm2d(channels, momentum=momentum)


@ -0,0 +1,43 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""network config setting, will be used in train.py and eval.py"""
class Config_CNNCTC():
# model config
CHARACTER = '0123456789abcdefghijklmnopqrstuvwxyz'
NUM_CLASS = len(CHARACTER) + 1
HIDDEN_SIZE = 512
FINAL_FEATURE_WIDTH = 26
# dataset config
IMG_H = 32
IMG_W = 100
TRAIN_DATASET_PATH = 'CNNCTC_Data/ST_MJ/'
TRAIN_DATASET_INDEX_PATH = 'CNNCTC_Data/st_mj_fixed_length_index_list.pkl'
TRAIN_BATCH_SIZE = 192
TEST_DATASET_PATH = 'CNNCTC_Data/IIIT5k_3000'
TEST_BATCH_SIZE = 256
TEST_DATASET_SIZE = 2976
TRAIN_EPOCHS = 3
# training config
CKPT_PATH = ''
SAVE_PATH = './'
LR = 1e-4
LR_PARA = 5e-4
MOMENTUM = 0.8
LOSS_SCALE = 8096
SAVE_CKPT_PER_N_STEP = 2000
KEEP_CKPT_MAX_NUM = 5


@ -0,0 +1,265 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""cnn_ctc dataset"""
import sys
import pickle
import math
import six
import numpy as np
from PIL import Image
import lmdb
from mindspore.communication.management import get_rank, get_group_size
from .util import CTCLabelConverter
from .config import Config_CNNCTC
config = Config_CNNCTC()
class NormalizePAD():
def __init__(self, max_size, PAD_type='right'):
self.max_size = max_size
self.PAD_type = PAD_type
def __call__(self, img):
# toTensor
img = np.array(img, dtype=np.float32)
img = img.transpose([2, 0, 1])
img = img.astype(np.float32)
img = np.true_divide(img, 255)
# normalize
img = np.subtract(img, 0.5)
img = np.true_divide(img, 0.5)
_, _, w = img.shape
Pad_img = np.zeros(shape=self.max_size, dtype=np.float32)
Pad_img[:, :, :w] = img # right pad
if self.max_size[2] != w: # add border Pad
Pad_img[:, :, w:] = np.tile(np.expand_dims(img[:, :, w - 1], 2), (1, 1, self.max_size[2] - w))
return Pad_img
class AlignCollate():
def __init__(self, imgH=32, imgW=100):
self.imgH = imgH
self.imgW = imgW
def __call__(self, images):
resized_max_w = self.imgW
input_channel = 3
transform = NormalizePAD((input_channel, self.imgH, resized_max_w))
resized_images = []
for image in images:
w, h = image.size
ratio = w / float(h)
if math.ceil(self.imgH * ratio) > self.imgW:
resized_w = self.imgW
else:
resized_w = math.ceil(self.imgH * ratio)
resized_image = image.resize((resized_w, self.imgH), Image.BICUBIC)
resized_images.append(transform(resized_image))
image_tensors = np.concatenate([np.expand_dims(t, 0) for t in resized_images], 0)
return image_tensors
def get_img_from_lmdb(env, index):
with env.begin(write=False) as txn:
label_key = 'label-%09d'.encode() % index
label = txn.get(label_key).decode('utf-8')
img_key = 'image-%09d'.encode() % index
imgbuf = txn.get(img_key)
buf = six.BytesIO()
buf.write(imgbuf)
buf.seek(0)
try:
img = Image.open(buf).convert('RGB') # for color image
except IOError:
print(f'Corrupted image for {index}')
# make dummy image and dummy label for corrupted image.
img = Image.new('RGB', (config.IMG_W, config.IMG_H))
label = '[dummy_label]'
label = label.lower()
return img, label
class ST_MJ_Generator_batch_fixed_length:
def __init__(self):
self.align_collector = AlignCollate()
self.converter = CTCLabelConverter(config.CHARACTER)
self.env = lmdb.open(config.TRAIN_DATASET_PATH, max_readers=32, readonly=True, lock=False, readahead=False,
meminit=False)
if not self.env:
print('cannot create lmdb from %s' % (config.TRAIN_DATASET_PATH))
raise ValueError(config.TRAIN_DATASET_PATH)
with open(config.TRAIN_DATASET_INDEX_PATH, 'rb') as f:
self.st_mj_filtered_index_list = pickle.load(f)
print(f'num of samples in ST_MJ dataset: {len(self.st_mj_filtered_index_list)}')
self.dataset_size = len(self.st_mj_filtered_index_list) // config.TRAIN_BATCH_SIZE
self.batch_size = config.TRAIN_BATCH_SIZE
def __len__(self):
return self.dataset_size
def __getitem__(self, item):
img_ret = []
text_ret = []
for i in range(item * self.batch_size, (item + 1) * self.batch_size):
index = self.st_mj_filtered_index_list[i]
img, label = get_img_from_lmdb(self.env, index)
img_ret.append(img)
text_ret.append(label)
img_ret = self.align_collector(img_ret)
text_ret, length = self.converter.encode(text_ret)
label_indices = []
for i, _ in enumerate(length):
for j in range(length[i]):
label_indices.append((i, j))
label_indices = np.array(label_indices, np.int64)
sequence_length = np.array([config.FINAL_FEATURE_WIDTH] * config.TRAIN_BATCH_SIZE, dtype=np.int32)
text_ret = text_ret.astype(np.int32)
return img_ret, label_indices, text_ret, sequence_length
class ST_MJ_Generator_batch_fixed_length_para:
def __init__(self):
self.align_collector = AlignCollate()
self.converter = CTCLabelConverter(config.CHARACTER)
self.env = lmdb.open(config.TRAIN_DATASET_PATH, max_readers=32, readonly=True, lock=False, readahead=False,
meminit=False)
if not self.env:
print('cannot create lmdb from %s' % (config.TRAIN_DATASET_PATH))
raise ValueError(config.TRAIN_DATASET_PATH)
with open(config.TRAIN_DATASET_INDEX_PATH, 'rb') as f:
self.st_mj_filtered_index_list = pickle.load(f)
print(f'num of samples in ST_MJ dataset: {len(self.st_mj_filtered_index_list)}')
self.rank_id = get_rank()
self.rank_size = get_group_size()
self.dataset_size = len(self.st_mj_filtered_index_list) // config.TRAIN_BATCH_SIZE // self.rank_size
self.batch_size = config.TRAIN_BATCH_SIZE
def __len__(self):
return self.dataset_size
def __getitem__(self, item):
img_ret = []
text_ret = []
rank_item = (item * self.rank_size) + self.rank_id
for i in range(rank_item * self.batch_size, (rank_item + 1) * self.batch_size):
index = self.st_mj_filtered_index_list[i]
img, label = get_img_from_lmdb(self.env, index)
img_ret.append(img)
text_ret.append(label)
img_ret = self.align_collector(img_ret)
text_ret, length = self.converter.encode(text_ret)
label_indices = []
for i, _ in enumerate(length):
for j in range(length[i]):
label_indices.append((i, j))
label_indices = np.array(label_indices, np.int64)
sequence_length = np.array([config.FINAL_FEATURE_WIDTH] * config.TRAIN_BATCH_SIZE, dtype=np.int32)
text_ret = text_ret.astype(np.int32)
return img_ret, label_indices, text_ret, sequence_length
def IIIT_Generator_batch():
max_len = int((26 + 1) // 2)
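# With 26 output time steps, CTC (which must insert a blank between repeated characters) can encode at most
# (26 + 1) // 2 = 13 label characters in the worst case, so longer labels are filtered out below.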
align_collector = AlignCollate()
converter = CTCLabelConverter(config.CHARACTER)
env = lmdb.open(config.TEST_DATASET_PATH, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False)
if not env:
print('cannot create lmdb from %s' % (config.TEST_DATASET_PATH))
sys.exit(0)
with env.begin(write=False) as txn:
nSamples = int(txn.get('num-samples'.encode()))
# Filtering
filtered_index_list = []
for index in range(nSamples):
index += 1 # lmdb starts with 1
label_key = 'label-%09d'.encode() % index
label = txn.get(label_key).decode('utf-8')
if len(label) > max_len:
continue
illegal_sample = False
for char_item in label.lower():
if char_item not in config.CHARACTER:
illegal_sample = True
break
if illegal_sample:
continue
filtered_index_list.append(index)
img_ret = []
text_ret = []
print(f'num of samples in IIIT dataset: {len(filtered_index_list)}')
for index in filtered_index_list:
img, label = get_img_from_lmdb(env, index)
img_ret.append(img)
text_ret.append(label)
if len(img_ret) == config.TEST_BATCH_SIZE:
img_ret = align_collector(img_ret)
text_ret, length = converter.encode(text_ret)
label_indices = []
for i, _ in enumerate(length):
for j in range(length[i]):
label_indices.append((i, j))
label_indices = np.array(label_indices, np.int64)
sequence_length = np.array([26] * config.TEST_BATCH_SIZE, dtype=np.int32)
text_ret = text_ret.astype(np.int32)
yield img_ret, label_indices, text_ret, sequence_length, length
img_ret = []
text_ret = []


@ -0,0 +1,88 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""generate ascend rank file"""
import os
import json
import socket
import argparse
parser = argparse.ArgumentParser(description="ascend distribute rank.")
parser.add_argument("--rank_file", type=str, default="scripts/rank_table_8p.json", help="rank_tabel_file_path.")
def main(rank_table_file):
nproc_per_node = 8
visible_devices = ['0', '1', '2', '3', '4', '5', '6', '7']
server_id = socket.gethostbyname(socket.gethostname())
hccn_configs = open('/etc/hccn.conf', 'r').readlines()
device_ips = {}
for hccn_item in hccn_configs:
hccn_item = hccn_item.strip()
if hccn_item.startswith('address_'):
device_id, device_ip = hccn_item.split('=')
device_id = device_id.split('_')[1]
device_ips[device_id] = device_ip
print('device_id:{}, device_ip:{}'.format(device_id, device_ip))
hccn_table = {}
hccn_table['board_id'] = '0x002f' # A+K
# hccn_table['board_id'] = '0x0000' # A+X
hccn_table['chip_info'] = '910'
hccn_table['deploy_mode'] = 'lab'
hccn_table['group_count'] = '1'
hccn_table['group_list'] = []
instance_list = []
for instance_id in range(nproc_per_node):
instance = {}
instance['devices'] = []
device_id = visible_devices[instance_id]
device_ip = device_ips[device_id]
instance['devices'].append({
'device_id': device_id,
'device_ip': device_ip,
})
instance['rank_id'] = str(instance_id)
instance['server_id'] = server_id
instance_list.append(instance)
hccn_table['group_list'].append({
'device_num': str(nproc_per_node),
'server_num': '1',
'group_name': '',
'instance_count': str(nproc_per_node),
'instance_list': instance_list,
})
hccn_table['para_plane_nic_location'] = 'device'
hccn_table['para_plane_nic_name'] = []
for instance_id in range(nproc_per_node):
eth_id = visible_devices[instance_id]
hccn_table['para_plane_nic_name'].append('eth{}'.format(eth_id))
hccn_table['para_plane_nic_num'] = str(nproc_per_node)
hccn_table['status'] = 'completed'
with open(rank_table_file, 'w') as table_fp:
json.dump(hccn_table, table_fp, indent=4)
if __name__ == '__main__':
args_opt = parser.parse_args()
rank_table = args_opt.rank_file
if os.path.exists(rank_table):
print('Rank table file exists.')
else:
print('Generating rank table file.')
main(rank_table)
print('Rank table file generated')


@ -0,0 +1,171 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""preprocess dataset"""
import random
import pickle
import numpy as np
import lmdb
from tqdm import tqdm
def combine_lmdbs(lmdb_paths, lmdb_save_path):
max_len = int((26 + 1) // 2)
character = '0123456789abcdefghijklmnopqrstuvwxyz'
env_save = lmdb.open(
lmdb_save_path,
map_size=1099511627776)
cnt = 0
for lmdb_path in lmdb_paths:
env = lmdb.open(lmdb_path, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False)
with env.begin(write=False) as txn:
nSamples = int(txn.get('num-samples'.encode()))
# Filtering
for index in tqdm(range(nSamples)):
index += 1 # lmdb starts with 1
label_key = 'label-%09d'.encode() % index
label = txn.get(label_key).decode('utf-8')
if len(label) > max_len:
continue
illegal_sample = False
for char_item in label.lower():
if char_item not in character:
illegal_sample = True
break
if illegal_sample:
continue
img_key = 'image-%09d'.encode() % index
imgbuf = txn.get(img_key)
with env_save.begin(write=True) as txn_save:
cnt += 1
label_key_save = 'label-%09d'.encode() % cnt
label_save = label.encode()
image_key_save = 'image-%09d'.encode() % cnt
image_save = imgbuf
txn_save.put(label_key_save, label_save)
txn_save.put(image_key_save, image_save)
nSamples = cnt
with env_save.begin(write=True) as txn_save:
txn_save.put('num-samples'.encode(), str(nSamples).encode())
def analyze_lmdb_label_length(lmdb_path, batch_size=192, num_of_combinations=1000):
label_length_dict = {}
env = lmdb.open(lmdb_path, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False)
with env.begin(write=False) as txn:
nSamples = int(txn.get('num-samples'.encode()))
for index in tqdm(range(nSamples)):
index += 1 # lmdb starts with 1
label_key = 'label-%09d'.encode() % index
label = txn.get(label_key).decode('utf-8')
label_length = len(label)
if label_length in label_length_dict:
label_length_dict[label_length] += 1
else:
label_length_dict[label_length] = 1
sorted_label_length = sorted(label_length_dict.items(), key=lambda x: x[1], reverse=True)
label_length_sum = 0
label_num = 0
lengths = []
p = []
for l, num in sorted_label_length:
label_length_sum += l * num
label_num += num
p.append(num)
lengths.append(l)
for i, _ in enumerate(p):
p[i] /= label_num
average_overall_length = int(label_length_sum / label_num * batch_size)
def get_combinations_of_fix_length(fix_length, items, p, batch_size):
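# Sample batch_size - 1 label lengths from the empirical length distribution, then try to choose a final
# length so the batch's total label length equals fix_length; this keeps the concatenated CTC label tensor
# (and label_indices) the same shape for every batch, which graph mode requires.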
ret = []
cur_sum = 0
ret = np.random.choice(items, batch_size - 1, True, p)
cur_sum = sum(ret)
ret = list(ret)
if fix_length - cur_sum in items:
ret.append(fix_length - cur_sum)
else:
return None
return ret
result = []
while len(result) < num_of_combinations:
ret = get_combinations_of_fix_length(average_overall_length, lengths, p, batch_size)
if ret is not None:
result.append(ret)
return result
def generate_fix_shape_index_list(lmdb_path, combinations, pkl_save_path, num_of_iters=70000):
length_index_dict = {}
env = lmdb.open(lmdb_path, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False)
with env.begin(write=False) as txn:
nSamples = int(txn.get('num-samples'.encode()))
for index in tqdm(range(nSamples)):
index += 1 # lmdb starts with 1
label_key = 'label-%09d'.encode() % index
label = txn.get(label_key).decode('utf-8')
label_length = len(label)
if label_length in length_index_dict:
length_index_dict[label_length].append(index)
else:
length_index_dict[label_length] = [index]
ret = []
for _ in range(num_of_iters):
comb = random.choice(combinations)
for l in comb:
ret.append(random.choice(length_index_dict[l]))
with open(pkl_save_path, 'wb') as f:
pickle.dump(ret, f, -1)
if __name__ == '__main__':
# step 1: combine the SynthText dataset and MJSynth dataset into a single lmdb file
print('Begin to combine multiple lmdb datasets')
combine_lmdbs(['/home/workspace/mindspore_dataset/CNNCTC_Data/1_ST/',
'/home/workspace/mindspore_dataset/CNNCTC_Data/MJ_train/'],
'/home/workspace/mindspore_dataset/CNNCTC_Data/ST_MJ')
# step 2: generate the order of input data, guarantee that the input batch shape is fixed
print('Begin to generate the index order of input data')
combination = analyze_lmdb_label_length('/home/workspace/mindspore_dataset/CNNCTC_Data/ST_MJ')
generate_fix_shape_index_list('/home/workspace/mindspore_dataset/CNNCTC_Data/ST_MJ', combination,
'/home/workspace/mindspore_dataset/CNNCTC_Data/st_mj_fixed_length_index_list.pkl')
print('Done')


@ -0,0 +1,102 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""util file"""
import numpy as np
class AverageMeter():
"""Computes and stores the average and current value"""
def __init__(self):
self.reset()
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
class CTCLabelConverter():
""" Convert between text-label and text-index """
def __init__(self, character):
# character (str): set of the possible characters.
dict_character = list(character)
self.dict = {}
for i, char in enumerate(dict_character):
self.dict[char] = i
self.character = dict_character + ['[blank]']  # dummy '[blank]' token for CTCLoss, appended as the last class
self.dict['[blank]'] = len(dict_character)
def encode(self, text):
"""convert text-label into text-index.
input:
text: text labels of each image. [batch_size]
output:
text: concatenated text index for CTCLoss.
[sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
length: length of each text. [batch_size]
"""
length = [len(s) for s in text]
text = ''.join(text)
text = [self.dict[char] for char in text]
return np.array(text), np.array(length)
def decode(self, text_index, length):
""" convert text-index into text-label. """
texts = []
index = 0
for l in length:
t = text_index[index:index + l]
char_list = []
for i in range(l):
if t[i] != self.dict['[blank]'] and not (i > 0 and t[i - 1] == t[i]):  # removing repeated characters and blank.
char_list.append(self.character[t[i]])
text = ''.join(char_list)
texts.append(text)
index += l
return texts
def reverse_encode(self, text_index, length):
""" convert text-index into text-label. """
texts = []
index = 0
for l in length:
t = text_index[index:index + l]
char_list = []
for i in range(l):
if t[i] != self.dict['[blank]']:  # removing the blank token.
char_list.append(self.character[t[i]])
text = ''.join(char_list)
texts.append(text)
index += l
return texts


@ -0,0 +1,100 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""cnnctc train"""
import argparse
import ast
import mindspore
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.dataset import GeneratorDataset
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.model import Model
from mindspore.communication.management import init
from mindspore.common import set_seed
from src.config import Config_CNNCTC
from src.callback import LossCallBack
from src.dataset import ST_MJ_Generator_batch_fixed_length, ST_MJ_Generator_batch_fixed_length_para
from src.cnn_ctc import CNNCTC_Model, ctc_loss, WithLossCell
set_seed(1)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False,
save_graphs_path=".", enable_auto_mixed_precision=False)
def dataset_creator(run_distribute):
if run_distribute:
st_dataset = ST_MJ_Generator_batch_fixed_length_para()
else:
st_dataset = ST_MJ_Generator_batch_fixed_length()
ds = GeneratorDataset(st_dataset,
['img', 'label_indices', 'text', 'sequence_length'],
num_parallel_workers=8)
return ds
def train(args_opt, config):
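# For multi-device runs, initialize HCCL and switch to data-parallel mode; each rank then reads its own
# shard of batches through ST_MJ_Generator_batch_fixed_length_para.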
if args_opt.run_distribute:
init()
context.set_auto_parallel_context(parallel_mode="data_parallel")
ds = dataset_creator(args_opt.run_distribute)
net = CNNCTC_Model(config.NUM_CLASS, config.HIDDEN_SIZE, config.FINAL_FEATURE_WIDTH)
net.set_train(True)
if config.CKPT_PATH != '':
param_dict = load_checkpoint(config.CKPT_PATH)
load_param_into_net(net, param_dict)
print('parameters loaded!')
else:
print('train from scratch...')
criterion = ctc_loss()
opt = mindspore.nn.RMSProp(params=net.trainable_params(), centered=True, learning_rate=config.LR_PARA,
momentum=config.MOMENTUM, loss_scale=config.LOSS_SCALE)
net = WithLossCell(net, criterion)
loss_scale_manager = mindspore.train.loss_scale_manager.FixedLossScaleManager(config.LOSS_SCALE, False)
model = Model(net, optimizer=opt, loss_scale_manager=loss_scale_manager, amp_level="O2")
callback = LossCallBack()
config_ck = CheckpointConfig(save_checkpoint_steps=config.SAVE_CKPT_PER_N_STEP,
keep_checkpoint_max=config.KEEP_CKPT_MAX_NUM)
ckpoint_cb = ModelCheckpoint(prefix="CNNCTC", config=config_ck, directory=config.SAVE_PATH)
if args_opt.device_id == 0:
model.train(config.TRAIN_EPOCHS, ds, callbacks=[callback, ckpoint_cb], dataset_sink_mode=False)
else:
model.train(config.TRAIN_EPOCHS, ds, callbacks=[callback], dataset_sink_mode=False)
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='CNNCTC arg')
parser.add_argument('--device_id', type=int, default=0, help="Device id, default is 0.")
parser.add_argument("--ckpt_path", type=str, default="", help="Pretrain file path.")
parser.add_argument("--run_distribute", type=ast.literal_eval, default=False,
help="Run distribute, default is false.")
args_cfg = parser.parse_args()
cfg = Config_CNNCTC()
if args_cfg.ckpt_path != "":
cfg.CKPT_PATH = args_cfg.ckpt_path
train(args_cfg, cfg)


@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#" :===========================================================================
# ===========================================================================
"""
network config setting, will be used in train.py and eval.py
"""