forked from mindspore-Ecosystem/mindspore
!6383 Add modelzoo CNNCTC Network.
Merge pull request !6383 from linqingke/fasterrcnn
This commit is contained in:
commit
737e27d721
|
@ -0,0 +1,354 @@
|
|||
# Contents
|
||||
|
||||
- [CNNCTC Description](#CNNCTC-description)
|
||||
- [Model Architecture](#model-architecture)
|
||||
- [Dataset](#dataset)
|
||||
- [Features](#features)
|
||||
- [Mixed Precision](#mixed-precision)
|
||||
- [Environment Requirements](#environment-requirements)
|
||||
- [Quick Start](#quick-start)
|
||||
- [Script Description](#script-description)
|
||||
- [Script and Sample Code](#script-and-sample-code)
|
||||
- [Script Parameters](#script-parameters)
|
||||
- [Training Process](#training-process)
|
||||
- [Training](#training)
|
||||
- [Distributed Training](#distributed-training)
|
||||
- [Evaluation Process](#evaluation-process)
|
||||
- [Evaluation](#evaluation)
|
||||
- [Model Description](#model-description)
|
||||
- [Performance](#performance)
|
||||
- [Evaluation Performance](#evaluation-performance)
|
||||
- [Inference Performance](#evaluation-performance)
|
||||
- [How to use](#how-to-use)
|
||||
- [Inference](#inference)
|
||||
- [Continue Training on the Pretrained Model](#continue-training-on-the-pretrained-model)
|
||||
- [Transfer Learning](#transfer-learning)
|
||||
- [Description of Random Situation](#description-of-random-situation)
|
||||
- [ModelZoo Homepage](#modelzoo-homepage)
|
||||
|
||||
|
||||
# [CNNCTC Description](#contents)
|
||||
This paper proposes three major contributions to addresses scene text recognition (STR).
|
||||
First, we examine the inconsistencies of training and evaluation datasets, and the performance gap results from inconsistencies.
|
||||
Second, we introduce a unified four-stage STR framework that most existing STR models fit into.
|
||||
Using this framework allows for the extensive evaluation of previously proposed STR modules and the discovery of previously
|
||||
unexplored module combinations. Third, we analyze the module-wise contributions to performance in terms of accuracy, speed,
|
||||
and memory demand, under one consistent set of training and evaluation datasets. Such analyses clean up the hindrance on the current
|
||||
comparisons to understand the performance gain of the existing modules.
|
||||
[Paper](https://arxiv.org/abs/1904.01906): J. Baek, G. Kim, J. Lee, S. Park, D. Han, S. Yun, S. J. Oh, and H. Lee, “What is wrong with scene text recognition model comparisons? dataset and model analysis,” ArXiv, vol. abs/1904.01906, 2019.
|
||||
|
||||
# [Model Architecture](#contents)
|
||||
This is an example of training CNN+CTC model for text recognition on MJSynth and SynthText dataset with MindSpore.
|
||||
|
||||
|
||||
|
||||
# [Dataset](#contents)
|
||||
|
||||
The [MJSynth](https://www.robots.ox.ac.uk/~vgg/data/text/) and [SynthText](https://github.com/ankush-me/SynthText) dataset are used for model training. The [The IIIT 5K-word dataset](https://cvit.iiit.ac.in/research/projects/cvit-projects/the-iiit-5k-word-dataset) dataset is used for evaluation.
|
||||
|
||||
- step 1:
|
||||
All the datasets have been preprocessed and stored in .lmdb format and can be downloaded [**HERE**](https://drive.google.com/drive/folders/192UfE9agQUMNq6AgU3_E05_FcPZK4hyt).
|
||||
|
||||
- step 2:
|
||||
Uncompress the downloaded file, rename the MJSynth dataset as MJ, the SynthText dataset as ST and the IIIT dataset as IIIT.
|
||||
|
||||
- step 3:
|
||||
Move above mentioned three datasets into `cnnctc_data` folder, and the structure should be as below:
|
||||
```
|
||||
|--- CNNCTC/
|
||||
|--- cnnctc_data/
|
||||
|--- ST/
|
||||
data.mdb
|
||||
lock.mdb
|
||||
|--- MJ/
|
||||
data.mdb
|
||||
lock.mdb
|
||||
|--- IIIT/
|
||||
data.mdb
|
||||
lock.mdb
|
||||
|
||||
......
|
||||
```
|
||||
|
||||
- step 4:
|
||||
Preprocess the dataset by running:
|
||||
```
|
||||
python src/preprocess_dataset.py
|
||||
```
|
||||
|
||||
This takes around 75 minutes.
|
||||
|
||||
# [Features](#contents)
|
||||
|
||||
## Mixed Precision
|
||||
|
||||
The [mixed precision](https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/mixed_precision.html) training method accelerates the deep learning neural network training process by using both the single-precision and half-precision data formats, and maintains the network precision achieved by the single-precision training at the same time. Mixed precision training can accelerate the computation process, reduce memory usage, and enable a larger model or batch size to be trained on specific hardware.
|
||||
For FP16 operators, if the input data type is FP32, the backend of MindSpore will automatically handle it with reduced precision. Users could check the reduced-precision operators by enabling INFO log and then searching ‘reduce precision’.
|
||||
|
||||
|
||||
|
||||
# [Environment Requirements](#contents)
|
||||
|
||||
- Hardware(Ascend)
|
||||
|
||||
- Prepare hardware environment with Ascend processor. If you want to try Ascend , please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
|
||||
- Framework
|
||||
|
||||
- [MindSpore](https://www.mindspore.cn/install/en)
|
||||
- For more information, please check the resources below:
|
||||
- [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html)
|
||||
|
||||
- [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html)
|
||||
|
||||
|
||||
|
||||
|
||||
# [Quick Start](#contents)
|
||||
|
||||
- Install dependencies:
|
||||
```
|
||||
pip install lmdb
|
||||
pip install Pillow
|
||||
pip install tqdm
|
||||
pip install six
|
||||
```
|
||||
|
||||
- Standalone Training:
|
||||
```
|
||||
bash scripts/run_standalone_train_ascend.sh $PRETRAINED_CKPT
|
||||
```
|
||||
|
||||
- Distributed Training:
|
||||
```
|
||||
bash scripts/run_distribute_train_ascend.sh $RANK_TABLE_FILE $PRETRAINED_CKPT
|
||||
```
|
||||
|
||||
- Evaluation:
|
||||
```
|
||||
bash scripts/run_eval_ascend.sh $TRAINED_CKPT
|
||||
```
|
||||
|
||||
# [Script Description](#contents)
|
||||
|
||||
## [Script and Sample Code](#contents)
|
||||
The entire code structure is as following:
|
||||
```
|
||||
|--- CNNCTC/
|
||||
|---README.md // descriptions about cnnctc
|
||||
|---train.py // train scripts
|
||||
|---eval.py // eval scripts
|
||||
|---scripts
|
||||
|---run_standalone_train_ascend.sh // shell script for standalone on ascend
|
||||
|---run_distribute_train_ascend.sh // shell script for distributed on ascend
|
||||
|---run_eval_ascend.sh // shell script for eval on ascend
|
||||
|---src
|
||||
|---__init__.py // init file
|
||||
|---cnn_ctc.py // cnn_ctc network
|
||||
|---config.py // total config
|
||||
|---callback.py // loss callback file
|
||||
|---dataset.py // process dataset
|
||||
|---util.py // routine operation
|
||||
|---generate_hccn_file.py // generate distribute json file
|
||||
|---preprocess_dataset.py // preprocess dataset
|
||||
|
||||
```
|
||||
|
||||
|
||||
## [Script Parameters](#contents)
|
||||
Parameters for both training and evaluation can be set in `config.py`.
|
||||
|
||||
Arguments:
|
||||
* `--CHARACTER`: Character labels.
|
||||
* `--NUM_CLASS`: The number of classes including all character labels and the <blank> label for CTCLoss.
|
||||
* `--HIDDEN_SIZE`: Model hidden size.
|
||||
* `--FINAL_FEATURE_WIDTH`: The number of features.
|
||||
* `--IMG_H`: The height of input image.
|
||||
* `--IMG_W`: The width of input image.
|
||||
* `--TRAIN_DATASET_PATH`: The path to training dataset.
|
||||
* `--TRAIN_DATASET_INDEX_PATH`: The path to training dataset index file which determines the order .
|
||||
* `--TRAIN_BATCH_SIZE`: Training batch size. The batch size and index file must ensure input data is in fixed shape.
|
||||
* `--TRAIN_DATASET_SIZE`: Training dataset size.
|
||||
* `--TEST_DATASET_PATH`: The path to test dataset.
|
||||
* `--TEST_BATCH_SIZE`: Test batch size.
|
||||
* `--TEST_DATASET_SIZE`:Test dataset size.
|
||||
* `--TRAIN_EPOCHS`:Total training epochs.
|
||||
* `--CKPT_PATH`:The path to model checkpoint file, can be used to resume training and evaluation.
|
||||
* `--SAVE_PATH`:The path to save model checkpoint file.
|
||||
* `--LR`:Learning rate for standalone training.
|
||||
* `--LR_PARA`:Learning rate for distributed training.
|
||||
* `--MOMENTUM`:Momentum.
|
||||
* `--LOSS_SCALE`:Loss scale to prevent gradient underflow.
|
||||
* `--SAVE_CKPT_PER_N_STEP`:Save model checkpoint file per N steps.
|
||||
* `--KEEP_CKPT_MAX_NUM`:The maximum number of saved model checkpoint file.
|
||||
|
||||
## [Training Process](#contents)
|
||||
|
||||
### Training
|
||||
|
||||
- Standalone Training:
|
||||
```
|
||||
bash scripts/run_standalone_train_ascend.sh $PRETRAINED_CKPT
|
||||
```
|
||||
|
||||
Results and checkpoints are written to `./train` folder. Log can be found in `./train/log` and loss values are recorded in `./train/loss.log`.
|
||||
|
||||
`$PRETRAINED_CKPT` is the path to model checkpoint and it is **optional**. If none is given the model will be trained from scratch.
|
||||
|
||||
- Distributed Training:
|
||||
```
|
||||
bash scripts/run_distribute_train_ascend.sh $RANK_TABLE_FILE $PRETRAINED_CKPT
|
||||
```
|
||||
|
||||
Results and checkpoints are written to `./train_parallel_{i}` folder for device `i` respectively.
|
||||
Log can be found in `./train_parallel_{i}/log_{i}.log` and loss values are recorded in `./train_parallel_{i}/loss.log`.
|
||||
|
||||
`$RANK_TABLE_FILE` is needed when you are running a distribute task on ascend.
|
||||
`$PATH_TO_CHECKPOINT` is the path to model checkpoint and it is **optional**. If none is given the model will be trained from scratch.
|
||||
|
||||
### Training Result
|
||||
|
||||
Training result will be stored in the example path, whose folder name begins with "train" or "train_parallel". You can find checkpoint file together with result like the followings in loss.log.
|
||||
|
||||
|
||||
```
|
||||
# distribute training result(8p)
|
||||
epoch: 1 step: 1 , loss is 76.25, average time per step is 0.335177839748392712
|
||||
epoch: 1 step: 2 , loss is 73.46875, average time per step is 0.36798572540283203
|
||||
epoch: 1 step: 3 , loss is 69.46875, average time per step is 0.3429678678512573
|
||||
epoch: 1 step: 4 , loss is 64.3125, average time per step is 0.33512671788533527
|
||||
epoch: 1 step: 5 , loss is 58.375, average time per step is 0.33149147033691406
|
||||
epoch: 1 step: 6 , loss is 52.7265625, average time per step is 0.3292975425720215
|
||||
...
|
||||
epoch: 1 step: 8689 , loss is 9.706798802612482, average time per step is 0.3184656601312549
|
||||
epoch: 1 step: 8690 , loss is 9.70612545289855, average time per step is 0.3184725407765116
|
||||
epoch: 1 step: 8691 , loss is 9.70695776049204, average time per step is 0.31847309686135555
|
||||
epoch: 1 step: 8692 , loss is 9.707279624277456, average time per step is 0.31847339290613375
|
||||
epoch: 1 step: 8693 , loss is 9.70763437950938, average time per step is 0.3184720295013031
|
||||
epoch: 1 step: 8694 , loss is 9.707695425072046, average time per step is 0.31847410284595573
|
||||
epoch: 1 step: 8695 , loss is 9.708408273381295, average time per step is 0.31847338271072345
|
||||
epoch: 1 step: 8696 , loss is 9.708703753591953, average time per step is 0.3184726025560777
|
||||
epoch: 1 step: 8697 , loss is 9.709536406025824, average time per step is 0.31847212061114694
|
||||
epoch: 1 step: 8698 , loss is 9.708542263610315, average time per step is 0.3184715309307257
|
||||
```
|
||||
|
||||
## [Evaluation Process](#contents)
|
||||
|
||||
### Evaluation
|
||||
- Evaluation:
|
||||
```
|
||||
bash scripts/run_eval_ascend.sh $TRAINED_CKPT
|
||||
```
|
||||
|
||||
The model will be evaluated on the IIIT dataset, sample results and overall accuracy will be printed.
|
||||
|
||||
|
||||
# [Model Description](#contents)
|
||||
## [Performance](#contents)
|
||||
|
||||
### Training Performance
|
||||
|
||||
| Parameters | FasterRcnn |
|
||||
| -------------------------- | ----------------------------------------------------------- |
|
||||
| Model Version | V1 |
|
||||
| Resource | Ascend 910 ;CPU 2.60GHz,192cores;Memory,755G |
|
||||
| uploaded Date | 09/28/2020 (month/day/year) |
|
||||
| MindSpore Version | 1.0.0 |
|
||||
| Dataset | MJSynth,SynthText |
|
||||
| Training Parameters | epoch=3, batch_size=192 |
|
||||
| Optimizer | RMSProp |
|
||||
| Loss Function | CTCLoss |
|
||||
| Speed | 1pc: 300 ms/step; 8pcs: 310 ms/step |
|
||||
| Total time | 1pc: 18 hours; 8pcs: 2.3 hours |
|
||||
| Parameters (M) | 177 |
|
||||
| Scripts | https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/cnnctc |
|
||||
|
||||
|
||||
### Evaluation Performance
|
||||
|
||||
| Parameters | FasterRcnn |
|
||||
| ------------------- | --------------------------- |
|
||||
| Model Version | V1 |
|
||||
| Resource | Ascend 910 |
|
||||
| Uploaded Date | 09/28/2020 (month/day/year) |
|
||||
| MindSpore Version | 1.0.0 |
|
||||
| Dataset | IIIT5K |
|
||||
| batch_size | 192 |
|
||||
| outputs | Accuracy |
|
||||
| Accuracy | 85% |
|
||||
| Model for inference | 675M (.ckpt file) |
|
||||
|
||||
## [How to use](#contents)
|
||||
### Inference
|
||||
|
||||
If you need to use the trained model to perform inference on multiple hardware platforms, such as GPU, Ascend 910 or Ascend 310, you can refer to this [Link](https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/network_migration.html). Following the steps below, this is a simple example:
|
||||
|
||||
- Running on Ascend
|
||||
|
||||
```
|
||||
# Set context
|
||||
context.set_context(mode=context.GRAPH_HOME, device_target=cfg.device_target)
|
||||
context.set_context(device_id=cfg.device_id)
|
||||
|
||||
# Load unseen dataset for inference
|
||||
dataset = dataset.create_dataset(cfg.data_path, 1, False)
|
||||
|
||||
# Define model
|
||||
net = CNNCTC(cfg.NUM_CLASS, cfg.HIDDEN_SIZE, cfg.FINAL_FEATURE_WIDTH)
|
||||
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01,
|
||||
cfg.momentum, weight_decay=cfg.weight_decay)
|
||||
loss = P.CTCLoss(preprocess_collapse_repeated=False,
|
||||
ctc_merge_repeated=True,
|
||||
ignore_longer_outputs_than_inputs=False)
|
||||
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
|
||||
|
||||
# Load pre-trained model
|
||||
param_dict = load_checkpoint(cfg.checkpoint_path)
|
||||
load_param_into_net(net, param_dict)
|
||||
net.set_train(False)
|
||||
|
||||
# Make predictions on the unseen dataset
|
||||
acc = model.eval(dataset)
|
||||
print("accuracy: ", acc)
|
||||
```
|
||||
|
||||
### Continue Training on the Pretrained Model
|
||||
|
||||
- running on Ascend
|
||||
|
||||
```
|
||||
# Load dataset
|
||||
dataset = create_dataset(cfg.data_path, 1)
|
||||
batch_num = dataset.get_dataset_size()
|
||||
|
||||
# Define model
|
||||
net = CNNCTC(cfg.NUM_CLASS, cfg.HIDDEN_SIZE, cfg.FINAL_FEATURE_WIDTH)
|
||||
# Continue training if set pre_trained to be True
|
||||
if cfg.pre_trained:
|
||||
param_dict = load_checkpoint(cfg.checkpoint_path)
|
||||
load_param_into_net(net, param_dict)
|
||||
lr = lr_steps(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size,
|
||||
steps_per_epoch=batch_num)
|
||||
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()),
|
||||
Tensor(lr), cfg.momentum, weight_decay=cfg.weight_decay)
|
||||
loss = P.CTCLoss(preprocess_collapse_repeated=False,
|
||||
ctc_merge_repeated=True,
|
||||
ignore_longer_outputs_than_inputs=False)
|
||||
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},
|
||||
amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None)
|
||||
|
||||
# Set callbacks
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5,
|
||||
keep_checkpoint_max=cfg.keep_checkpoint_max)
|
||||
time_cb = TimeMonitor(data_size=batch_num)
|
||||
ckpoint_cb = ModelCheckpoint(prefix="train_googlenet_cifar10", directory="./",
|
||||
config=config_ck)
|
||||
loss_cb = LossMonitor()
|
||||
|
||||
# Start training
|
||||
model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb])
|
||||
print("train success")
|
||||
```
|
||||
|
||||
|
||||
# [ModelZoo Homepage](#contents)
|
||||
Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
|
|
@ -0,0 +1,109 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""cnnctc eval"""
|
||||
|
||||
import argparse
|
||||
import time
|
||||
import numpy as np
|
||||
|
||||
from mindspore import Tensor, context
|
||||
import mindspore.common.dtype as mstype
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.dataset import GeneratorDataset
|
||||
|
||||
from src.util import CTCLabelConverter, AverageMeter
|
||||
from src.config import Config_CNNCTC
|
||||
from src.dataset import IIIT_Generator_batch
|
||||
from src.cnn_ctc import CNNCTC_Model
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False,
|
||||
save_graphs_path=".", enable_auto_mixed_precision=False)
|
||||
|
||||
def test_dataset_creator():
|
||||
ds = GeneratorDataset(IIIT_Generator_batch, ['img', 'label_indices', 'text', 'sequence_length', 'label_str'])
|
||||
return ds
|
||||
|
||||
|
||||
def test(config):
|
||||
ds = test_dataset_creator()
|
||||
|
||||
net = CNNCTC_Model(config.NUM_CLASS, config.HIDDEN_SIZE, config.FINAL_FEATURE_WIDTH)
|
||||
|
||||
ckpt_path = config.CKPT_PATH
|
||||
param_dict = load_checkpoint(ckpt_path)
|
||||
load_param_into_net(net, param_dict)
|
||||
print('parameters loaded! from: ', ckpt_path)
|
||||
|
||||
converter = CTCLabelConverter(config.CHARACTER)
|
||||
|
||||
model_run_time = AverageMeter()
|
||||
npu_to_cpu_time = AverageMeter()
|
||||
postprocess_time = AverageMeter()
|
||||
|
||||
count = 0
|
||||
correct_count = 0
|
||||
for data in ds.create_tuple_iterator():
|
||||
img, _, text, _, length = data
|
||||
|
||||
img_tensor = Tensor(img, mstype.float32)
|
||||
|
||||
model_run_begin = time.time()
|
||||
model_predict = net(img_tensor)
|
||||
model_run_end = time.time()
|
||||
model_run_time.update(model_run_end - model_run_begin)
|
||||
|
||||
npu_to_cpu_begin = time.time()
|
||||
model_predict = np.squeeze(model_predict.asnumpy())
|
||||
npu_to_cpu_end = time.time()
|
||||
npu_to_cpu_time.update(npu_to_cpu_end - npu_to_cpu_begin)
|
||||
|
||||
postprocess_begin = time.time()
|
||||
preds_size = np.array([model_predict.shape[1]] * config.TEST_BATCH_SIZE)
|
||||
preds_index = np.argmax(model_predict, 2)
|
||||
preds_index = np.reshape(preds_index, [-1])
|
||||
preds_str = converter.decode(preds_index, preds_size)
|
||||
postprocess_end = time.time()
|
||||
postprocess_time.update(postprocess_end - postprocess_begin)
|
||||
|
||||
label_str = converter.reverse_encode(text.asnumpy(), length.asnumpy())
|
||||
|
||||
if count == 0:
|
||||
model_run_time.reset()
|
||||
npu_to_cpu_time.reset()
|
||||
postprocess_time.reset()
|
||||
else:
|
||||
print('---------model run time--------', model_run_time.avg)
|
||||
print('---------npu_to_cpu run time--------', npu_to_cpu_time.avg)
|
||||
print('---------postprocess run time--------', postprocess_time.avg)
|
||||
|
||||
print("Prediction samples: \n", preds_str[:5])
|
||||
print("Ground truth: \n", label_str[:5])
|
||||
for pred, label in zip(preds_str, label_str):
|
||||
if pred == label:
|
||||
correct_count += 1
|
||||
count += 1
|
||||
|
||||
print('accuracy: ', correct_count / count)
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description="FasterRcnn training")
|
||||
parser.add_argument('--device_id', type=int, default=0, help="Device id, default is 0.")
|
||||
parser.add_argument("--ckpt_path", type=str, default="", help="trained file path.")
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
cfg = Config_CNNCTC()
|
||||
if args_opt.ckpt_path != "":
|
||||
cfg.CKPT_PATH = args_opt.ckpt_path
|
||||
test(cfg)
|
|
@ -0,0 +1,57 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
current_exec_path=$(pwd)
|
||||
echo ${current_exec_path}
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
PATH1=$(get_real_path $1)
|
||||
echo $PATH1
|
||||
|
||||
PATH2=$(get_real_path $2)
|
||||
echo $PATH2
|
||||
|
||||
python ${current_exec_path}/src/generate_hccn_file.py --rank_file=$PATH1
|
||||
export RANK_TABLE_FILE=$PATH1
|
||||
export RANK_SIZE=8
|
||||
ulimit -u unlimited
|
||||
for((i=0;i<$RANK_SIZE;i++));
|
||||
do
|
||||
rm ./train_parallel_$i/ -rf
|
||||
mkdir ./train_parallel_$i
|
||||
cp ./*.py ./train_parallel_$i
|
||||
cp ./scripts/*.sh ./train_parallel_$i
|
||||
cp -r ./src ./train_parallel_$i
|
||||
cd ./train_parallel_$i || exit
|
||||
export RANK_ID=$i
|
||||
export DEVICE_ID=$i
|
||||
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||
if [ -f $PATH2 ]
|
||||
then
|
||||
python train.py --device_id=$i --ckpt_path=$PATH2 --run_distribute=True >log_$i.log 2>&1 &
|
||||
else
|
||||
python train.py --device_id=$i --run_distribute=True >log_$i.log 2>&1 &
|
||||
fi
|
||||
cd .. || exit
|
||||
done
|
||||
|
|
@ -0,0 +1,54 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# -ne 1 ]
|
||||
then
|
||||
echo "Usage: sh run_eval_ascend.sh [TRAINED_CKPT]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
PATH1=$(get_real_path $1)
|
||||
echo $PATH1
|
||||
if [ ! -f $PATH1 ]
|
||||
then
|
||||
echo "error: TRAINED_CKPT=$PATH1 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_ID=0
|
||||
|
||||
if [ -d "eval" ];
|
||||
then
|
||||
rm -rf ./eval
|
||||
fi
|
||||
mkdir ./eval
|
||||
cp ./*.py ./eval
|
||||
cp ./scripts/*.sh ./eval
|
||||
cp -r ./src ./eval
|
||||
cd ./eval || exit
|
||||
echo "start infering for device $DEVICE_ID"
|
||||
env > env.log
|
||||
python eval.py --device_id=$DEVICE_ID --ckpt_path=$PATH1 &> log &
|
||||
cd .. || exit
|
|
@ -0,0 +1,45 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
get_real_path(){
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
PATH1=$(get_real_path $1)
|
||||
|
||||
ulimit -u unlimited
|
||||
|
||||
if [ -d "train" ];
|
||||
then
|
||||
rm -rf ./train
|
||||
fi
|
||||
mkdir ./train
|
||||
cp ./*.py ./train
|
||||
cp ./scripts/*.sh ./train
|
||||
cp -r ./src ./train
|
||||
cd ./train || exit
|
||||
echo "start training for device $DEVICE_ID"
|
||||
env > env.log
|
||||
if [ -f $PATH1 ]
|
||||
then
|
||||
python train.py --device_id=$DEVICE_ID --ckpt_path=$PATH1 --run_distribute=False &> log &
|
||||
else
|
||||
python train.py --device_id=$DEVICE_ID --run_distribute=False &> log &
|
||||
fi
|
||||
cd .. || exit
|
|
@ -0,0 +1,15 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""src init file"""
|
|
@ -0,0 +1,71 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""loss callback"""
|
||||
|
||||
import time
|
||||
from mindspore.train.callback import Callback
|
||||
from .util import AverageMeter
|
||||
|
||||
class LossCallBack(Callback):
|
||||
"""
|
||||
Monitor the loss in training.
|
||||
|
||||
If the loss is NAN or INF terminating training.
|
||||
|
||||
Note:
|
||||
If per_print_times is 0 do not print loss.
|
||||
|
||||
Args:
|
||||
per_print_times (int): Print loss every times. Default: 1.
|
||||
"""
|
||||
|
||||
def __init__(self, per_print_times=1):
|
||||
super(LossCallBack, self).__init__()
|
||||
if not isinstance(per_print_times, int) or per_print_times < 0:
|
||||
raise ValueError("print_step must be int and >= 0.")
|
||||
self._per_print_times = per_print_times
|
||||
self.loss_avg = AverageMeter()
|
||||
self.timer = AverageMeter()
|
||||
self.start_time = time.time()
|
||||
|
||||
def step_end(self, run_context):
|
||||
cb_params = run_context.original_args()
|
||||
|
||||
loss = cb_params.net_outputs.asnumpy()
|
||||
|
||||
cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
|
||||
cur_num = cb_params.cur_step_num
|
||||
|
||||
if cur_step_in_epoch % 2000 == 1:
|
||||
self.loss_avg = AverageMeter()
|
||||
self.timer = AverageMeter()
|
||||
self.start_time = time.time()
|
||||
else:
|
||||
self.timer.update(time.time() - self.start_time)
|
||||
self.start_time = time.time()
|
||||
|
||||
self.loss_avg.update(loss)
|
||||
|
||||
if self._per_print_times != 0 and cur_num % self._per_print_times == 0:
|
||||
loss_file = open("./loss.log", "a+")
|
||||
loss_file.write("epoch: %s step: %s , loss is %s, average time per step is %s" % (
|
||||
cb_params.cur_epoch_num, cur_step_in_epoch,
|
||||
self.loss_avg.avg, self.timer.avg))
|
||||
loss_file.write("\n")
|
||||
loss_file.close()
|
||||
|
||||
print("epoch: %s step: %s , loss is %s, average time per step is %s" % (
|
||||
cb_params.cur_epoch_num, cur_step_in_epoch,
|
||||
self.loss_avg.avg, self.timer.avg))
|
|
@ -0,0 +1,255 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""cnn_ctc network define"""
|
||||
|
||||
import mindspore.nn as nn
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.common.initializer import TruncatedNormal, initializer
|
||||
import mindspore.common.dtype as mstype
|
||||
|
||||
class CNNCTC_Model(nn.Cell):
|
||||
|
||||
def __init__(self, num_class, hidden_size, final_feature_width):
|
||||
super(CNNCTC_Model, self).__init__()
|
||||
|
||||
self.num_class = num_class
|
||||
self.hidden_size = hidden_size
|
||||
self.final_feature_width = final_feature_width
|
||||
|
||||
self.FeatureExtraction = ResNet_FeatureExtractor()
|
||||
self.Prediction = nn.Dense(self.hidden_size, self.num_class)
|
||||
|
||||
self.transpose = P.Transpose()
|
||||
self.reshape = P.Reshape()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.FeatureExtraction(x)
|
||||
x = self.transpose(x, (0, 3, 1, 2)) # [b, c, h, w] -> [b, w, c, h]
|
||||
|
||||
x = self.reshape(x, (-1, self.hidden_size))
|
||||
x = self.Prediction(x)
|
||||
x = self.reshape(x, (-1, self.final_feature_width, self.num_class))
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class WithLossCell(nn.Cell):
|
||||
|
||||
def __init__(self, backbone, loss_fn):
|
||||
super(WithLossCell, self).__init__(auto_prefix=False)
|
||||
self._backbone = backbone
|
||||
self._loss_fn = loss_fn
|
||||
|
||||
def construct(self, img, label_indices, text, sequence_length):
|
||||
model_predict = self._backbone(img)
|
||||
return self._loss_fn(model_predict, label_indices, text, sequence_length)
|
||||
|
||||
@property
|
||||
def backbone_network(self):
|
||||
return self._backbone
|
||||
|
||||
class ctc_loss(nn.Cell):
|
||||
|
||||
def __init__(self):
|
||||
super(ctc_loss, self).__init__()
|
||||
|
||||
self.loss = P.CTCLoss(preprocess_collapse_repeated=False,
|
||||
ctc_merge_repeated=True,
|
||||
ignore_longer_outputs_than_inputs=False)
|
||||
|
||||
self.mean = P.ReduceMean()
|
||||
self.transpose = P.Transpose()
|
||||
self.reshape = P.Reshape()
|
||||
|
||||
def construct(self, inputs, labels_indices, labels_values, sequence_length):
|
||||
inputs = self.transpose(inputs, (1, 0, 2))
|
||||
|
||||
loss, _ = self.loss(inputs, labels_indices, labels_values, sequence_length)
|
||||
|
||||
loss = self.mean(loss)
|
||||
return loss
|
||||
|
||||
|
||||
class ResNet_FeatureExtractor(nn.Cell):
|
||||
def __init__(self):
|
||||
super(ResNet_FeatureExtractor, self).__init__()
|
||||
self.ConvNet = ResNet(3, 512, BasicBlock, [1, 2, 5, 3])
|
||||
|
||||
def construct(self, featuremap):
|
||||
return self.ConvNet(featuremap)
|
||||
|
||||
|
||||
class ResNet(nn.Cell):
|
||||
def __init__(self, input_channel, output_channel, block, layers):
|
||||
super(ResNet, self).__init__()
|
||||
|
||||
self.output_channel_block = [int(output_channel / 4), int(output_channel / 2), output_channel, output_channel]
|
||||
|
||||
self.inplanes = int(output_channel / 8)
|
||||
self.conv0_1 = ms_conv3x3(input_channel, int(output_channel / 16), stride=1, padding=1, pad_mode='pad')
|
||||
self.bn0_1 = ms_fused_bn(int(output_channel / 16))
|
||||
self.conv0_2 = ms_conv3x3(int(output_channel / 16), self.inplanes, stride=1, padding=1, pad_mode='pad')
|
||||
self.bn0_2 = ms_fused_bn(self.inplanes)
|
||||
self.relu = P.ReLU()
|
||||
|
||||
self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='valid')
|
||||
self.layer1 = self._make_layer(block, self.output_channel_block[0], layers[0])
|
||||
self.conv1 = ms_conv3x3(self.output_channel_block[0], self.output_channel_block[0], stride=1, padding=1,
|
||||
pad_mode='pad')
|
||||
self.bn1 = ms_fused_bn(self.output_channel_block[0])
|
||||
|
||||
self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='valid')
|
||||
self.layer2 = self._make_layer(block, self.output_channel_block[1], layers[1])
|
||||
self.conv2 = ms_conv3x3(self.output_channel_block[1], self.output_channel_block[1], stride=1, padding=1,
|
||||
pad_mode='pad')
|
||||
self.bn2 = ms_fused_bn(self.output_channel_block[1])
|
||||
|
||||
self.pad = P.Pad(((0, 0), (0, 0), (0, 0), (1, 1)))
|
||||
self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=(2, 1), pad_mode='valid')
|
||||
self.layer3 = self._make_layer(block, self.output_channel_block[2], layers[2])
|
||||
self.conv3 = ms_conv3x3(self.output_channel_block[2], self.output_channel_block[2], stride=1, padding=1,
|
||||
pad_mode='pad')
|
||||
self.bn3 = ms_fused_bn(self.output_channel_block[2])
|
||||
|
||||
self.layer4 = self._make_layer(block, self.output_channel_block[3], layers[3])
|
||||
self.conv4_1 = ms_conv2x2(self.output_channel_block[3], self.output_channel_block[3], stride=(2, 1),
|
||||
pad_mode='valid')
|
||||
self.bn4_1 = ms_fused_bn(self.output_channel_block[3])
|
||||
|
||||
self.conv4_2 = ms_conv2x2(self.output_channel_block[3], self.output_channel_block[3], stride=1, padding=0,
|
||||
pad_mode='valid')
|
||||
self.bn4_2 = ms_fused_bn(self.output_channel_block[3])
|
||||
|
||||
def _make_layer(self, block, planes, blocks, stride=1):
|
||||
downsample = None
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.SequentialCell(
|
||||
[ms_conv1x1(self.inplanes, planes * block.expansion, stride=stride),
|
||||
ms_fused_bn(planes * block.expansion)]
|
||||
)
|
||||
|
||||
layers = []
|
||||
layers.append(block(self.inplanes, planes, stride, downsample))
|
||||
self.inplanes = planes * block.expansion
|
||||
for _ in range(1, blocks):
|
||||
layers.append(block(self.inplanes, planes))
|
||||
|
||||
return nn.SequentialCell(layers)
|
||||
|
||||
def construct(self, x):
|
||||
x = self.conv0_1(x)
|
||||
x = self.bn0_1(x)
|
||||
x = self.relu(x)
|
||||
x = self.conv0_2(x)
|
||||
x = self.bn0_2(x)
|
||||
x = self.relu(x)
|
||||
|
||||
x = self.maxpool1(x)
|
||||
x = self.layer1(x)
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
|
||||
x = self.maxpool2(x)
|
||||
x = self.layer2(x)
|
||||
x = self.conv2(x)
|
||||
x = self.bn2(x)
|
||||
x = self.relu(x)
|
||||
|
||||
x = self.pad(x)
|
||||
x = self.maxpool3(x)
|
||||
x = self.layer3(x)
|
||||
x = self.conv3(x)
|
||||
x = self.bn3(x)
|
||||
x = self.relu(x)
|
||||
|
||||
x = self.layer4(x)
|
||||
x = self.pad(x)
|
||||
x = self.conv4_1(x)
|
||||
x = self.bn4_1(x)
|
||||
x = self.relu(x)
|
||||
x = self.conv4_2(x)
|
||||
x = self.bn4_2(x)
|
||||
x = self.relu(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class BasicBlock(nn.Cell):
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
||||
super(BasicBlock, self).__init__()
|
||||
|
||||
self.conv1 = ms_conv3x3(inplanes, planes, stride=stride, padding=1, pad_mode='pad')
|
||||
self.bn1 = ms_fused_bn(planes)
|
||||
self.conv2 = ms_conv3x3(planes, planes, stride=stride, padding=1, pad_mode='pad')
|
||||
self.bn2 = ms_fused_bn(planes)
|
||||
self.relu = P.ReLU()
|
||||
self.downsample = downsample
|
||||
self.add = P.TensorAdd()
|
||||
|
||||
def construct(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(x)
|
||||
out = self.add(out, residual)
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
def weight_variable(shape, factor=0.1, half_precision=False):
|
||||
if half_precision:
|
||||
return initializer(TruncatedNormal(0.02), shape, dtype=mstype.float16)
|
||||
|
||||
return TruncatedNormal(0.02)
|
||||
|
||||
|
||||
def ms_conv3x3(in_channels, out_channels, stride=1, padding=0, pad_mode='same', has_bias=False):
|
||||
"""Get a conv2d layer with 3x3 kernel size."""
|
||||
init_value = weight_variable((out_channels, in_channels, 3, 3))
|
||||
return nn.Conv2d(in_channels, out_channels,
|
||||
kernel_size=3, stride=stride, padding=padding, pad_mode=pad_mode, weight_init=init_value,
|
||||
has_bias=has_bias)
|
||||
|
||||
|
||||
def ms_conv1x1(in_channels, out_channels, stride=1, padding=0, pad_mode='same', has_bias=False):
|
||||
"""Get a conv2d layer with 1x1 kernel size."""
|
||||
init_value = weight_variable((out_channels, in_channels, 1, 1))
|
||||
return nn.Conv2d(in_channels, out_channels,
|
||||
kernel_size=1, stride=stride, padding=padding, pad_mode=pad_mode, weight_init=init_value,
|
||||
has_bias=has_bias)
|
||||
|
||||
|
||||
def ms_conv2x2(in_channels, out_channels, stride=1, padding=0, pad_mode='same', has_bias=False):
|
||||
"""Get a conv2d layer with 2x2 kernel size."""
|
||||
init_value = weight_variable((out_channels, in_channels, 1, 1))
|
||||
return nn.Conv2d(in_channels, out_channels,
|
||||
kernel_size=2, stride=stride, padding=padding, pad_mode=pad_mode, weight_init=init_value,
|
||||
has_bias=has_bias)
|
||||
|
||||
|
||||
def ms_fused_bn(channels, momentum=0.1):
|
||||
"""Get a fused batchnorm"""
|
||||
return nn.BatchNorm2d(channels, momentum=momentum)
|
|
@ -0,0 +1,43 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""network config setting, will be used in train.py and eval.py"""
|
||||
|
||||
class Config_CNNCTC():
|
||||
# model config
|
||||
CHARACTER = '0123456789abcdefghijklmnopqrstuvwxyz'
|
||||
NUM_CLASS = len(CHARACTER) + 1
|
||||
HIDDEN_SIZE = 512
|
||||
FINAL_FEATURE_WIDTH = 26
|
||||
|
||||
# dataset config
|
||||
IMG_H = 32
|
||||
IMG_W = 100
|
||||
TRAIN_DATASET_PATH = 'CNNCTC_Data/ST_MJ/'
|
||||
TRAIN_DATASET_INDEX_PATH = 'CNNCTC_Data/st_mj_fixed_length_index_list.pkl'
|
||||
TRAIN_BATCH_SIZE = 192
|
||||
TEST_DATASET_PATH = 'CNNCTC_Data/IIIT5k_3000'
|
||||
TEST_BATCH_SIZE = 256
|
||||
TEST_DATASET_SIZE = 2976
|
||||
TRAIN_EPOCHS = 3
|
||||
|
||||
# training config
|
||||
CKPT_PATH = ''
|
||||
SAVE_PATH = './'
|
||||
LR = 1e-4
|
||||
LR_PARA = 5e-4
|
||||
MOMENTUM = 0.8
|
||||
LOSS_SCALE = 8096
|
||||
SAVE_CKPT_PER_N_STEP = 2000
|
||||
KEEP_CKPT_MAX_NUM = 5
|
|
@ -0,0 +1,265 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""cnn_ctc dataset"""
|
||||
|
||||
import sys
|
||||
import pickle
|
||||
import math
|
||||
import six
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import lmdb
|
||||
|
||||
from mindspore.communication.management import get_rank, get_group_size
|
||||
|
||||
from .util import CTCLabelConverter
|
||||
from .config import Config_CNNCTC
|
||||
|
||||
config = Config_CNNCTC()
|
||||
|
||||
class NormalizePAD():
|
||||
|
||||
def __init__(self, max_size, PAD_type='right'):
|
||||
self.max_size = max_size
|
||||
self.PAD_type = PAD_type
|
||||
|
||||
def __call__(self, img):
|
||||
# toTensor
|
||||
img = np.array(img, dtype=np.float32)
|
||||
img = img.transpose([2, 0, 1])
|
||||
img = img.astype(np.float)
|
||||
img = np.true_divide(img, 255)
|
||||
# normalize
|
||||
img = np.subtract(img, 0.5)
|
||||
img = np.true_divide(img, 0.5)
|
||||
|
||||
_, _, w = img.shape
|
||||
Pad_img = np.zeros(shape=self.max_size, dtype=np.float32)
|
||||
Pad_img[:, :, :w] = img # right pad
|
||||
if self.max_size[2] != w: # add border Pad
|
||||
Pad_img[:, :, w:] = np.tile(np.expand_dims(img[:, :, w - 1], 2), (1, 1, self.max_size[2] - w))
|
||||
|
||||
return Pad_img
|
||||
|
||||
|
||||
class AlignCollate():
|
||||
|
||||
def __init__(self, imgH=32, imgW=100):
|
||||
self.imgH = imgH
|
||||
self.imgW = imgW
|
||||
|
||||
def __call__(self, images):
|
||||
|
||||
resized_max_w = self.imgW
|
||||
input_channel = 3
|
||||
transform = NormalizePAD((input_channel, self.imgH, resized_max_w))
|
||||
|
||||
resized_images = []
|
||||
for image in images:
|
||||
w, h = image.size
|
||||
ratio = w / float(h)
|
||||
if math.ceil(self.imgH * ratio) > self.imgW:
|
||||
resized_w = self.imgW
|
||||
else:
|
||||
resized_w = math.ceil(self.imgH * ratio)
|
||||
|
||||
resized_image = image.resize((resized_w, self.imgH), Image.BICUBIC)
|
||||
resized_images.append(transform(resized_image))
|
||||
|
||||
image_tensors = np.concatenate([np.expand_dims(t, 0) for t in resized_images], 0)
|
||||
|
||||
return image_tensors
|
||||
|
||||
|
||||
def get_img_from_lmdb(env, index):
|
||||
with env.begin(write=False) as txn:
|
||||
label_key = 'label-%09d'.encode() % index
|
||||
label = txn.get(label_key).decode('utf-8')
|
||||
img_key = 'image-%09d'.encode() % index
|
||||
imgbuf = txn.get(img_key)
|
||||
|
||||
buf = six.BytesIO()
|
||||
buf.write(imgbuf)
|
||||
buf.seek(0)
|
||||
try:
|
||||
img = Image.open(buf).convert('RGB') # for color image
|
||||
|
||||
except IOError:
|
||||
print(f'Corrupted image for {index}')
|
||||
# make dummy image and dummy label for corrupted image.
|
||||
img = Image.new('RGB', (config.IMG_W, config.IMG_H))
|
||||
label = '[dummy_label]'
|
||||
|
||||
label = label.lower()
|
||||
|
||||
return img, label
|
||||
|
||||
|
||||
class ST_MJ_Generator_batch_fixed_length:
|
||||
def __init__(self):
|
||||
self.align_collector = AlignCollate()
|
||||
self.converter = CTCLabelConverter(config.CHARACTER)
|
||||
self.env = lmdb.open(config.TRAIN_DATASET_PATH, max_readers=32, readonly=True, lock=False, readahead=False,
|
||||
meminit=False)
|
||||
if not self.env:
|
||||
print('cannot create lmdb from %s' % (config.TRAIN_DATASET_PATH))
|
||||
raise ValueError(config.TRAIN_DATASET_PATH)
|
||||
|
||||
with open(config.TRAIN_DATASET_INDEX_PATH, 'rb') as f:
|
||||
self.st_mj_filtered_index_list = pickle.load(f)
|
||||
|
||||
print(f'num of samples in ST_MJ dataset: {len(self.st_mj_filtered_index_list)}')
|
||||
self.dataset_size = len(self.st_mj_filtered_index_list) // config.TRAIN_BATCH_SIZE
|
||||
self.batch_size = config.TRAIN_BATCH_SIZE
|
||||
|
||||
def __len__(self):
|
||||
return self.dataset_size
|
||||
|
||||
def __getitem__(self, item):
|
||||
img_ret = []
|
||||
text_ret = []
|
||||
|
||||
for i in range(item * self.batch_size, (item + 1) * self.batch_size):
|
||||
index = self.st_mj_filtered_index_list[i]
|
||||
img, label = get_img_from_lmdb(self.env, index)
|
||||
|
||||
img_ret.append(img)
|
||||
text_ret.append(label)
|
||||
|
||||
img_ret = self.align_collector(img_ret)
|
||||
text_ret, length = self.converter.encode(text_ret)
|
||||
|
||||
label_indices = []
|
||||
for i, _ in enumerate(length):
|
||||
for j in range(length[i]):
|
||||
label_indices.append((i, j))
|
||||
label_indices = np.array(label_indices, np.int64)
|
||||
sequence_length = np.array([config.FINAL_FEATURE_WIDTH] * config.TRAIN_BATCH_SIZE, dtype=np.int32)
|
||||
text_ret = text_ret.astype(np.int32)
|
||||
|
||||
return img_ret, label_indices, text_ret, sequence_length
|
||||
|
||||
class ST_MJ_Generator_batch_fixed_length_para:
|
||||
def __init__(self):
|
||||
self.align_collector = AlignCollate()
|
||||
self.converter = CTCLabelConverter(config.CHARACTER)
|
||||
self.env = lmdb.open(config.TRAIN_DATASET_PATH, max_readers=32, readonly=True, lock=False, readahead=False,
|
||||
meminit=False)
|
||||
if not self.env:
|
||||
print('cannot create lmdb from %s' % (config.TRAIN_DATASET_PATH))
|
||||
raise ValueError(config.TRAIN_DATASET_PATH)
|
||||
|
||||
with open(config.TRAIN_DATASET_INDEX_PATH, 'rb') as f:
|
||||
self.st_mj_filtered_index_list = pickle.load(f)
|
||||
|
||||
print(f'num of samples in ST_MJ dataset: {len(self.st_mj_filtered_index_list)}')
|
||||
self.rank_id = get_rank()
|
||||
self.rank_size = get_group_size()
|
||||
self.dataset_size = len(self.st_mj_filtered_index_list) // config.TRAIN_BATCH_SIZE // self.rank_size
|
||||
self.batch_size = config.TRAIN_BATCH_SIZE
|
||||
|
||||
def __len__(self):
|
||||
return self.dataset_size
|
||||
|
||||
def __getitem__(self, item):
|
||||
img_ret = []
|
||||
text_ret = []
|
||||
|
||||
rank_item = (item * self.rank_size) + self.rank_id
|
||||
for i in range(rank_item * self.batch_size, (rank_item + 1) * self.batch_size):
|
||||
index = self.st_mj_filtered_index_list[i]
|
||||
img, label = get_img_from_lmdb(self.env, index)
|
||||
|
||||
img_ret.append(img)
|
||||
text_ret.append(label)
|
||||
|
||||
img_ret = self.align_collector(img_ret)
|
||||
text_ret, length = self.converter.encode(text_ret)
|
||||
|
||||
label_indices = []
|
||||
for i, _ in enumerate(length):
|
||||
for j in range(length[i]):
|
||||
label_indices.append((i, j))
|
||||
label_indices = np.array(label_indices, np.int64)
|
||||
sequence_length = np.array([config.FINAL_FEATURE_WIDTH] * config.TRAIN_BATCH_SIZE, dtype=np.int32)
|
||||
text_ret = text_ret.astype(np.int32)
|
||||
|
||||
return img_ret, label_indices, text_ret, sequence_length
|
||||
|
||||
|
||||
def IIIT_Generator_batch():
|
||||
max_len = int((26 + 1) // 2)
|
||||
|
||||
align_collector = AlignCollate()
|
||||
|
||||
converter = CTCLabelConverter(config.CHARACTER)
|
||||
|
||||
env = lmdb.open(config.TEST_DATASET_PATH, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False)
|
||||
if not env:
|
||||
print('cannot create lmdb from %s' % (config.TEST_DATASET_PATH))
|
||||
sys.exit(0)
|
||||
|
||||
with env.begin(write=False) as txn:
|
||||
nSamples = int(txn.get('num-samples'.encode()))
|
||||
nSamples = nSamples
|
||||
|
||||
# Filtering
|
||||
filtered_index_list = []
|
||||
for index in range(nSamples):
|
||||
index += 1 # lmdb starts with 1
|
||||
label_key = 'label-%09d'.encode() % index
|
||||
label = txn.get(label_key).decode('utf-8')
|
||||
|
||||
if len(label) > max_len:
|
||||
continue
|
||||
|
||||
illegal_sample = False
|
||||
for char_item in label.lower():
|
||||
if char_item not in config.CHARACTER:
|
||||
illegal_sample = True
|
||||
break
|
||||
if illegal_sample:
|
||||
continue
|
||||
|
||||
filtered_index_list.append(index)
|
||||
|
||||
img_ret = []
|
||||
text_ret = []
|
||||
|
||||
print(f'num of samples in IIIT dataset: {len(filtered_index_list)}')
|
||||
|
||||
for index in filtered_index_list:
|
||||
|
||||
img, label = get_img_from_lmdb(env, index)
|
||||
|
||||
img_ret.append(img)
|
||||
text_ret.append(label)
|
||||
|
||||
if len(img_ret) == config.TEST_BATCH_SIZE:
|
||||
img_ret = align_collector(img_ret)
|
||||
text_ret, length = converter.encode(text_ret)
|
||||
|
||||
label_indices = []
|
||||
for i, _ in enumerate(length):
|
||||
for j in range(length[i]):
|
||||
label_indices.append((i, j))
|
||||
label_indices = np.array(label_indices, np.int64)
|
||||
sequence_length = np.array([26] * config.TEST_BATCH_SIZE, dtype=np.int32)
|
||||
text_ret = text_ret.astype(np.int32)
|
||||
|
||||
yield img_ret, label_indices, text_ret, sequence_length, length
|
||||
|
||||
img_ret = []
|
||||
text_ret = []
|
|
@ -0,0 +1,88 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""generate ascend rank file"""
|
||||
|
||||
import os
|
||||
import socket
|
||||
import argparse
|
||||
|
||||
parser = argparse.ArgumentParser(description="ascend distribute rank.")
|
||||
parser.add_argument("--rank_file", type=str, default="scripts/rank_table_8p.json", help="rank_tabel_file_path.")
|
||||
|
||||
def main(rank_table_file):
|
||||
nproc_per_node = 8
|
||||
|
||||
visible_devices = ['0', '1', '2', '3', '4', '5', '6', '7']
|
||||
|
||||
server_id = socket.gethostbyname(socket.gethostname())
|
||||
|
||||
hccn_configs = open('/etc/hccn.conf', 'r').readlines()
|
||||
device_ips = {}
|
||||
for hccn_item in hccn_configs:
|
||||
hccn_item = hccn_item.strip()
|
||||
if hccn_item.startswith('address_'):
|
||||
device_id, device_ip = hccn_item.split('=')
|
||||
device_id = device_id.split('_')[1]
|
||||
device_ips[device_id] = device_ip
|
||||
print('device_id:{}, device_ip:{}'.format(device_id, device_ip))
|
||||
|
||||
hccn_table = {}
|
||||
hccn_table['board_id'] = '0x002f' # A+K
|
||||
# hccn_table['board_id'] = '0x0000' # A+X
|
||||
|
||||
hccn_table['chip_info'] = '910'
|
||||
hccn_table['deploy_mode'] = 'lab'
|
||||
hccn_table['group_count'] = '1'
|
||||
hccn_table['group_list'] = []
|
||||
instance_list = []
|
||||
for instance_id in range(nproc_per_node):
|
||||
instance = {}
|
||||
instance['devices'] = []
|
||||
device_id = visible_devices[instance_id]
|
||||
device_ip = device_ips[device_id]
|
||||
instance['devices'].append({
|
||||
'device_id': device_id,
|
||||
'device_ip': device_ip,
|
||||
})
|
||||
instance['rank_id'] = str(instance_id)
|
||||
instance['server_id'] = server_id
|
||||
instance_list.append(instance)
|
||||
hccn_table['group_list'].append({
|
||||
'device_num': str(nproc_per_node),
|
||||
'server_num': '1',
|
||||
'group_name': '',
|
||||
'instance_count': str(nproc_per_node),
|
||||
'instance_list': instance_list,
|
||||
})
|
||||
hccn_table['para_plane_nic_location'] = 'device'
|
||||
hccn_table['para_plane_nic_name'] = []
|
||||
for instance_id in range(nproc_per_node):
|
||||
eth_id = visible_devices[instance_id]
|
||||
hccn_table['para_plane_nic_name'].append('eth{}'.format(eth_id))
|
||||
hccn_table['para_plane_nic_num'] = str(nproc_per_node)
|
||||
hccn_table['status'] = 'completed'
|
||||
import json
|
||||
with open(rank_table_file, 'w') as table_fp:
|
||||
json.dump(hccn_table, table_fp, indent=4)
|
||||
|
||||
if __name__ == '__main__':
|
||||
args_opt = parser.parse_args()
|
||||
rank_table = args_opt.rank_file
|
||||
if os.path.exists(rank_table):
|
||||
print('Rank table file exists.')
|
||||
else:
|
||||
print('Generating rank table file.')
|
||||
main(rank_table)
|
||||
print('Rank table file generated')
|
|
@ -0,0 +1,171 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""preprocess dataset"""
|
||||
|
||||
import random
|
||||
import pickle
|
||||
import numpy as np
|
||||
import lmdb
|
||||
from tqdm import tqdm
|
||||
|
||||
def combine_lmdbs(lmdb_paths, lmdb_save_path):
|
||||
max_len = int((26 + 1) // 2)
|
||||
character = '0123456789abcdefghijklmnopqrstuvwxyz'
|
||||
|
||||
env_save = lmdb.open(
|
||||
lmdb_save_path,
|
||||
map_size=1099511627776)
|
||||
|
||||
cnt = 0
|
||||
for lmdb_path in lmdb_paths:
|
||||
env = lmdb.open(lmdb_path, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False)
|
||||
with env.begin(write=False) as txn:
|
||||
nSamples = int(txn.get('num-samples'.encode()))
|
||||
nSamples = nSamples
|
||||
|
||||
# Filtering
|
||||
for index in tqdm(range(nSamples)):
|
||||
index += 1 # lmdb starts with 1
|
||||
label_key = 'label-%09d'.encode() % index
|
||||
label = txn.get(label_key).decode('utf-8')
|
||||
|
||||
if len(label) > max_len:
|
||||
continue
|
||||
|
||||
illegal_sample = False
|
||||
for char_item in label.lower():
|
||||
if char_item not in character:
|
||||
illegal_sample = True
|
||||
break
|
||||
if illegal_sample:
|
||||
continue
|
||||
|
||||
img_key = 'image-%09d'.encode() % index
|
||||
imgbuf = txn.get(img_key)
|
||||
|
||||
with env_save.begin(write=True) as txn_save:
|
||||
cnt += 1
|
||||
|
||||
label_key_save = 'label-%09d'.encode() % cnt
|
||||
label_save = label.encode()
|
||||
image_key_save = 'image-%09d'.encode() % cnt
|
||||
image_save = imgbuf
|
||||
|
||||
txn_save.put(label_key_save, label_save)
|
||||
txn_save.put(image_key_save, image_save)
|
||||
|
||||
nSamples = cnt
|
||||
with env_save.begin(write=True) as txn_save:
|
||||
txn_save.put('num-samples'.encode(), str(nSamples).encode())
|
||||
|
||||
|
||||
def analyze_lmdb_label_length(lmdb_path, batch_size=192, num_of_combinations=1000):
|
||||
label_length_dict = {}
|
||||
|
||||
env = lmdb.open(lmdb_path, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False)
|
||||
with env.begin(write=False) as txn:
|
||||
nSamples = int(txn.get('num-samples'.encode()))
|
||||
nSamples = nSamples
|
||||
|
||||
for index in tqdm(range(nSamples)):
|
||||
index += 1 # lmdb starts with 1
|
||||
label_key = 'label-%09d'.encode() % index
|
||||
label = txn.get(label_key).decode('utf-8')
|
||||
|
||||
label_length = len(label)
|
||||
if label_length in label_length_dict:
|
||||
label_length_dict[label_length] += 1
|
||||
else:
|
||||
label_length_dict[label_length] = 1
|
||||
|
||||
sorted_label_length = sorted(label_length_dict.items(), key=lambda x: x[1], reverse=True)
|
||||
|
||||
label_length_sum = 0
|
||||
label_num = 0
|
||||
lengths = []
|
||||
p = []
|
||||
for l, num in sorted_label_length:
|
||||
label_length_sum += l * num
|
||||
label_num += num
|
||||
p.append(num)
|
||||
lengths.append(l)
|
||||
for i, _ in enumerate(p):
|
||||
p[i] /= label_num
|
||||
|
||||
average_overall_length = int(label_length_sum / label_num * batch_size)
|
||||
|
||||
def get_combinations_of_fix_length(fix_length, items, p, batch_size):
|
||||
ret = []
|
||||
cur_sum = 0
|
||||
ret = np.random.choice(items, batch_size - 1, True, p)
|
||||
cur_sum = sum(ret)
|
||||
ret = list(ret)
|
||||
if fix_length - cur_sum in items:
|
||||
ret.append(fix_length - cur_sum)
|
||||
else:
|
||||
return None
|
||||
return ret
|
||||
|
||||
result = []
|
||||
while len(result) < num_of_combinations:
|
||||
ret = get_combinations_of_fix_length(average_overall_length, lengths, p, batch_size)
|
||||
if ret is not None:
|
||||
result.append(ret)
|
||||
return result
|
||||
|
||||
|
||||
def generate_fix_shape_index_list(lmdb_path, combinations, pkl_save_path, num_of_iters=70000):
|
||||
length_index_dict = {}
|
||||
|
||||
env = lmdb.open(lmdb_path, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False)
|
||||
with env.begin(write=False) as txn:
|
||||
nSamples = int(txn.get('num-samples'.encode()))
|
||||
nSamples = nSamples
|
||||
|
||||
for index in tqdm(range(nSamples)):
|
||||
index += 1 # lmdb starts with 1
|
||||
label_key = 'label-%09d'.encode() % index
|
||||
label = txn.get(label_key).decode('utf-8')
|
||||
|
||||
label_length = len(label)
|
||||
if label_length in length_index_dict:
|
||||
length_index_dict[label_length].append(index)
|
||||
else:
|
||||
length_index_dict[label_length] = [index]
|
||||
|
||||
ret = []
|
||||
for _ in range(num_of_iters):
|
||||
comb = random.choice(combinations)
|
||||
for l in comb:
|
||||
ret.append(random.choice(length_index_dict[l]))
|
||||
|
||||
with open(pkl_save_path, 'wb') as f:
|
||||
pickle.dump(ret, f, -1)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# step 1: combine the SynthText dataset and MJSynth dataset into a single lmdb file
|
||||
print('Begin to combine multiple lmdb datasets')
|
||||
combine_lmdbs(['/home/workspace/mindspore_dataset/CNNCTC_Data/1_ST/',
|
||||
'/home/workspace/mindspore_dataset/CNNCTC_Data/MJ_train/'],
|
||||
'/home/workspace/mindspore_dataset/CNNCTC_Data/ST_MJ')
|
||||
|
||||
# step 2: generate the order of input data, guarantee that the input batch shape is fixed
|
||||
print('Begin to generate the index order of input data')
|
||||
combination = analyze_lmdb_label_length('/home/workspace/mindspore_dataset/CNNCTC_Data/ST_MJ')
|
||||
generate_fix_shape_index_list('/home/workspace/mindspore_dataset/CNNCTC_Data/ST_MJ', combination,
|
||||
'/home/workspace/mindspore_dataset/CNNCTC_Data/st_mj_fixed_length_index_list.pkl')
|
||||
|
||||
print('Done')
|
|
@ -0,0 +1,102 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""util file"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
class AverageMeter():
|
||||
"""Computes and stores the average and current value"""
|
||||
|
||||
def __init__(self):
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.val = 0
|
||||
self.avg = 0
|
||||
self.sum = 0
|
||||
self.count = 0
|
||||
|
||||
def update(self, val, n=1):
|
||||
self.val = val
|
||||
self.sum += val * n
|
||||
self.count += n
|
||||
self.avg = self.sum / self.count
|
||||
|
||||
|
||||
class CTCLabelConverter():
|
||||
""" Convert between text-label and text-index """
|
||||
|
||||
def __init__(self, character):
|
||||
# character (str): set of the possible characters.
|
||||
dict_character = list(character)
|
||||
|
||||
self.dict = {}
|
||||
for i, char in enumerate(dict_character):
|
||||
self.dict[char] = i
|
||||
|
||||
self.character = dict_character + ['[blank]'] # dummy '[blank]' token for CTCLoss (index 0)
|
||||
self.dict['[blank]'] = len(dict_character)
|
||||
|
||||
def encode(self, text):
|
||||
"""convert text-label into text-index.
|
||||
input:
|
||||
text: text labels of each image. [batch_size]
|
||||
|
||||
output:
|
||||
text: concatenated text index for CTCLoss.
|
||||
[sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
|
||||
length: length of each text. [batch_size]
|
||||
"""
|
||||
length = [len(s) for s in text]
|
||||
text = ''.join(text)
|
||||
text = [self.dict[char] for char in text]
|
||||
|
||||
return np.array(text), np.array(length)
|
||||
|
||||
def decode(self, text_index, length):
|
||||
""" convert text-index into text-label. """
|
||||
texts = []
|
||||
index = 0
|
||||
for l in length:
|
||||
t = text_index[index:index + l]
|
||||
|
||||
char_list = []
|
||||
for i in range(l):
|
||||
# if t[i] != self.dict['[blank]'] and (not (i > 0 and t[i - 1] == t[i])): # removing repeated characters and blank.
|
||||
if t[i] != self.dict['[blank]'] and (
|
||||
not (i > 0 and t[i - 1] == t[i])): # removing repeated characters and blank.
|
||||
char_list.append(self.character[t[i]])
|
||||
text = ''.join(char_list)
|
||||
|
||||
texts.append(text)
|
||||
index += l
|
||||
return texts
|
||||
|
||||
def reverse_encode(self, text_index, length):
|
||||
""" convert text-index into text-label. """
|
||||
texts = []
|
||||
index = 0
|
||||
for l in length:
|
||||
t = text_index[index:index + l]
|
||||
|
||||
char_list = []
|
||||
for i in range(l):
|
||||
if t[i] != self.dict['[blank]']: # removing repeated characters and blank.
|
||||
char_list.append(self.character[t[i]])
|
||||
text = ''.join(char_list)
|
||||
|
||||
texts.append(text)
|
||||
index += l
|
||||
return texts
|
|
@ -0,0 +1,100 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""cnnctc train"""
|
||||
|
||||
import argparse
|
||||
import ast
|
||||
|
||||
import mindspore
|
||||
from mindspore import context
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.dataset import GeneratorDataset
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.communication.management import init
|
||||
from mindspore.common import set_seed
|
||||
|
||||
from src.config import Config_CNNCTC
|
||||
from src.callback import LossCallBack
|
||||
from src.dataset import ST_MJ_Generator_batch_fixed_length, ST_MJ_Generator_batch_fixed_length_para
|
||||
from src.cnn_ctc import CNNCTC_Model, ctc_loss, WithLossCell
|
||||
|
||||
set_seed(1)
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False,
|
||||
save_graphs_path=".", enable_auto_mixed_precision=False)
|
||||
|
||||
|
||||
def dataset_creator(run_distribute):
|
||||
if run_distribute:
|
||||
st_dataset = ST_MJ_Generator_batch_fixed_length_para()
|
||||
else:
|
||||
st_dataset = ST_MJ_Generator_batch_fixed_length()
|
||||
|
||||
ds = GeneratorDataset(st_dataset,
|
||||
['img', 'label_indices', 'text', 'sequence_length'],
|
||||
num_parallel_workers=8)
|
||||
|
||||
return ds
|
||||
|
||||
|
||||
def train(args_opt, config):
|
||||
if args_opt.run_distribute:
|
||||
init()
|
||||
context.set_auto_parallel_context(parallel_mode="data_parallel")
|
||||
|
||||
ds = dataset_creator(args_opt.run_distribute)
|
||||
|
||||
net = CNNCTC_Model(config.NUM_CLASS, config.HIDDEN_SIZE, config.FINAL_FEATURE_WIDTH)
|
||||
net.set_train(True)
|
||||
|
||||
if config.CKPT_PATH != '':
|
||||
param_dict = load_checkpoint(config.CKPT_PATH)
|
||||
load_param_into_net(net, param_dict)
|
||||
print('parameters loaded!')
|
||||
else:
|
||||
print('train from scratch...')
|
||||
|
||||
criterion = ctc_loss()
|
||||
opt = mindspore.nn.RMSProp(params=net.trainable_params(), centered=True, learning_rate=config.LR_PARA,
|
||||
momentum=config.MOMENTUM, loss_scale=config.LOSS_SCALE)
|
||||
|
||||
net = WithLossCell(net, criterion)
|
||||
loss_scale_manager = mindspore.train.loss_scale_manager.FixedLossScaleManager(config.LOSS_SCALE, False)
|
||||
model = Model(net, optimizer=opt, loss_scale_manager=loss_scale_manager, amp_level="O2")
|
||||
|
||||
callback = LossCallBack()
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=config.SAVE_CKPT_PER_N_STEP,
|
||||
keep_checkpoint_max=config.KEEP_CKPT_MAX_NUM)
|
||||
ckpoint_cb = ModelCheckpoint(prefix="CNNCTC", config=config_ck, directory=config.SAVE_PATH)
|
||||
|
||||
if args_opt.device_id == 0:
|
||||
model.train(config.TRAIN_EPOCHS, ds, callbacks=[callback, ckpoint_cb], dataset_sink_mode=False)
|
||||
else:
|
||||
model.train(config.TRAIN_EPOCHS, ds, callbacks=[callback], dataset_sink_mode=False)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='CNNCTC arg')
|
||||
parser.add_argument('--device_id', type=int, default=0, help="Device id, default is 0.")
|
||||
parser.add_argument("--ckpt_path", type=str, default="", help="Pretrain file path.")
|
||||
parser.add_argument("--run_distribute", type=ast.literal_eval, default=False,
|
||||
help="Run distribute, default is false.")
|
||||
args_cfg = parser.parse_args()
|
||||
|
||||
cfg = Config_CNNCTC()
|
||||
if args_cfg.ckpt_path != "":
|
||||
cfg.CKPT_PATH = args_cfg.ckpt_path
|
||||
train(args_cfg, cfg)
|
|
@ -11,7 +11,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#" :===========================================================================
|
||||
# ===========================================================================
|
||||
"""
|
||||
network config setting, will be used in train.py and eval.py
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue