forked from mindspore-Ecosystem/mindspore
!6383 Add modelzoo CNNCTC Network.
Merge pull request !6383 from linqingke/fasterrcnn
This commit is contained in:
commit 737e27d721

@ -0,0 +1,354 @@
# Contents

- [CNNCTC Description](#cnnctc-description)
- [Model Architecture](#model-architecture)
- [Dataset](#dataset)
- [Features](#features)
    - [Mixed Precision](#mixed-precision)
- [Environment Requirements](#environment-requirements)
- [Quick Start](#quick-start)
- [Script Description](#script-description)
    - [Script and Sample Code](#script-and-sample-code)
    - [Script Parameters](#script-parameters)
    - [Training Process](#training-process)
        - [Training](#training)
        - [Distributed Training](#distributed-training)
    - [Evaluation Process](#evaluation-process)
        - [Evaluation](#evaluation)
- [Model Description](#model-description)
    - [Performance](#performance)
        - [Evaluation Performance](#evaluation-performance)
        - [Inference Performance](#inference-performance)
    - [How to use](#how-to-use)
        - [Inference](#inference)
        - [Continue Training on the Pretrained Model](#continue-training-on-the-pretrained-model)
        - [Transfer Learning](#transfer-learning)
- [Description of Random Situation](#description-of-random-situation)
- [ModelZoo Homepage](#modelzoo-homepage)

# [CNNCTC Description](#contents)

This paper proposes three major contributions to address scene text recognition (STR).
First, we examine the inconsistencies between training and evaluation datasets, and the performance gap that results from these inconsistencies.
Second, we introduce a unified four-stage STR framework that most existing STR models fit into.
Using this framework allows for the extensive evaluation of previously proposed STR modules and the discovery of previously unexplored module combinations.
Third, we analyze the module-wise contributions to performance in terms of accuracy, speed, and memory demand, under one consistent set of training and evaluation datasets.
Such analyses clean up the hindrance in current comparisons and help understand the performance gains of the existing modules.

[Paper](https://arxiv.org/abs/1904.01906): J. Baek, G. Kim, J. Lee, S. Park, D. Han, S. Yun, S. J. Oh, and H. Lee, "What is wrong with scene text recognition model comparisons? Dataset and model analysis," ArXiv, vol. abs/1904.01906, 2019.

# [Model Architecture](#contents)

This is an example of training a CNN+CTC model for text recognition on the MJSynth and SynthText datasets with MindSpore.

# [Dataset](#contents)

The [MJSynth](https://www.robots.ox.ac.uk/~vgg/data/text/) and [SynthText](https://github.com/ankush-me/SynthText) datasets are used for model training. The [IIIT 5K-word dataset](https://cvit.iiit.ac.in/research/projects/cvit-projects/the-iiit-5k-word-dataset) is used for evaluation.

- step 1:

All the datasets have been preprocessed, stored in .lmdb format, and can be downloaded [**HERE**](https://drive.google.com/drive/folders/192UfE9agQUMNq6AgU3_E05_FcPZK4hyt).

- step 2:

Uncompress the downloaded file, rename the MJSynth dataset as MJ, the SynthText dataset as ST, and the IIIT dataset as IIIT.

- step 3:

Move the above-mentioned three datasets into the `cnnctc_data` folder; the structure should be as below:

```
|--- CNNCTC/
    |--- cnnctc_data/
        |--- ST/
            data.mdb
            lock.mdb
        |--- MJ/
            data.mdb
            lock.mdb
        |--- IIIT/
            data.mdb
            lock.mdb

        ......
```

- step 4:

Preprocess the dataset by running:

```
python src/preprocess_dataset.py
```

This takes around 75 minutes.

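The layout expected after step 3 can be sanity-checked with a short script before preprocessing. This is a sketch, not part of the repository; the `CNNCTC` root path and the helper name are illustrative:

```python
import os

# The .mdb files that step 3 should leave in place, relative to the project root.
EXPECTED = [
    "cnnctc_data/ST/data.mdb", "cnnctc_data/ST/lock.mdb",
    "cnnctc_data/MJ/data.mdb", "cnnctc_data/MJ/lock.mdb",
    "cnnctc_data/IIIT/data.mdb", "cnnctc_data/IIIT/lock.mdb",
]

def missing_dataset_files(root):
    """Return the expected .mdb files that are absent under `root`."""
    return [p for p in EXPECTED if not os.path.isfile(os.path.join(root, p))]

if __name__ == "__main__":
    missing = missing_dataset_files("CNNCTC")
    if missing:
        print("missing files:", missing)
    else:
        print("dataset layout looks good")
```

Running this from the directory containing `CNNCTC/` prints any missing files, which is cheaper than discovering a bad layout 75 minutes into preprocessing.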
# [Features](#contents)

## Mixed Precision

The [mixed precision](https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/mixed_precision.html) training method accelerates the deep learning neural network training process by using both the single-precision and half-precision data formats, and maintains the network precision achieved by the single-precision training at the same time. Mixed precision training can accelerate the computation process, reduce memory usage, and enable a larger model or batch size to be trained on specific hardware.

For FP16 operators, if the input data type is FP32, the backend of MindSpore will automatically handle it with reduced precision. Users could check the reduced-precision operators by enabling INFO log and then searching 'reduce precision'.

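The motivation for the `LOSS_SCALE` parameter described later in `config.py` can be seen with plain NumPy, independently of MindSpore: a small FP32 gradient underflows to zero once cast to FP16, while the same value survives after scaling. The numbers here are illustrative only:

```python
import numpy as np

grad = np.float32(1e-8)          # a small FP32 gradient
loss_scale = np.float32(1024.0)  # illustrative scale factor

# Casting the raw gradient to FP16 underflows: the smallest positive
# FP16 subnormal is 2**-24 (about 6e-8), so 1e-8 rounds to zero.
assert np.float16(grad) == 0.0

# Scaling before the cast keeps the value representable in FP16 ...
scaled = np.float16(grad * loss_scale)
assert scaled != 0.0

# ... and dividing by the scale afterwards (in FP32) recovers the gradient.
recovered = np.float32(scaled) / loss_scale
print(recovered)
```

This is exactly the failure mode loss scaling guards against during FP16 training.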
# [Environment Requirements](#contents)

- Hardware (Ascend)
    - Prepare hardware environment with Ascend processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get the resources.
- Framework
    - [MindSpore](https://www.mindspore.cn/install/en)
- For more information, please check the resources below:
    - [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html)
    - [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html)

# [Quick Start](#contents)

- Install dependencies:

```
pip install lmdb
pip install Pillow
pip install tqdm
pip install six
```

- Standalone Training:

```
bash scripts/run_standalone_train_ascend.sh $PRETRAINED_CKPT
```

- Distributed Training:

```
bash scripts/run_distribute_train_ascend.sh $RANK_TABLE_FILE $PRETRAINED_CKPT
```

- Evaluation:

```
bash scripts/run_eval_ascend.sh $TRAINED_CKPT
```

# [Script Description](#contents)

## [Script and Sample Code](#contents)

The entire code structure is as follows:

```
|--- CNNCTC/
    |--- README.md    // descriptions about cnnctc
    |--- train.py    // train script
    |--- eval.py    // eval script
    |--- scripts
        |--- run_standalone_train_ascend.sh    // shell script for standalone training on Ascend
        |--- run_distribute_train_ascend.sh    // shell script for distributed training on Ascend
        |--- run_eval_ascend.sh    // shell script for evaluation on Ascend
    |--- src
        |--- __init__.py    // init file
        |--- cnn_ctc.py    // cnn_ctc network definition
        |--- config.py    // total config
        |--- callback.py    // loss callback file
        |--- dataset.py    // dataset processing
        |--- util.py    // routine operations
        |--- generate_hccn_file.py    // generate the distributed json file
        |--- preprocess_dataset.py    // dataset preprocessing
```

## [Script Parameters](#contents)

Parameters for both training and evaluation can be set in `config.py`.

Arguments:

* `--CHARACTER`: Character labels.
* `--NUM_CLASS`: The number of classes, including all character labels and the <blank> label for CTCLoss.
* `--HIDDEN_SIZE`: Model hidden size.
* `--FINAL_FEATURE_WIDTH`: The number of features.
* `--IMG_H`: The height of the input image.
* `--IMG_W`: The width of the input image.
* `--TRAIN_DATASET_PATH`: The path to the training dataset.
* `--TRAIN_DATASET_INDEX_PATH`: The path to the training dataset index file, which determines the sample order.
* `--TRAIN_BATCH_SIZE`: Training batch size. The batch size and index file must ensure that the input data has a fixed shape.
* `--TRAIN_DATASET_SIZE`: Training dataset size.
* `--TEST_DATASET_PATH`: The path to the test dataset.
* `--TEST_BATCH_SIZE`: Test batch size.
* `--TEST_DATASET_SIZE`: Test dataset size.
* `--TRAIN_EPOCHS`: Total training epochs.
* `--CKPT_PATH`: The path to the model checkpoint file; can be used to resume training and for evaluation.
* `--SAVE_PATH`: The path for saving model checkpoint files.
* `--LR`: Learning rate for standalone training.
* `--LR_PARA`: Learning rate for distributed training.
* `--MOMENTUM`: Momentum.
* `--LOSS_SCALE`: Loss scale to prevent gradient underflow.
* `--SAVE_CKPT_PER_N_STEP`: Save a model checkpoint file every N steps.
* `--KEEP_CKPT_MAX_NUM`: The maximum number of saved model checkpoint files.

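For orientation, a minimal config object with these fields might look like the sketch below. All concrete values are placeholders for illustration, not the values shipped in `src/config.py` (only `TRAIN_BATCH_SIZE=192` and `TRAIN_EPOCHS=3` are taken from the performance table later in this document):

```python
class ConfigCNNCTCSketch:
    """Illustrative stand-in for src/config.py; values are placeholders."""
    CHARACTER = '0123456789abcdefghijklmnopqrstuvwxyz'
    NUM_CLASS = len(CHARACTER) + 1   # all character labels + the <blank> label
    HIDDEN_SIZE = 512
    FINAL_FEATURE_WIDTH = 26
    IMG_H = 32
    IMG_W = 100
    TRAIN_BATCH_SIZE = 192
    TEST_BATCH_SIZE = 192
    TRAIN_EPOCHS = 3
    CKPT_PATH = ''
    SAVE_PATH = './'
    LR = 1e-4
    MOMENTUM = 0.8
    LOSS_SCALE = 1024
    SAVE_CKPT_PER_N_STEP = 2000
    KEEP_CKPT_MAX_NUM = 5

cfg = ConfigCNNCTCSketch()
print(cfg.NUM_CLASS)  # 37: 36 character labels + 1 blank
```

Note how `NUM_CLASS` is derived from `CHARACTER` plus one blank label, matching the description of `--NUM_CLASS` above.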
## [Training Process](#contents)

### Training

- Standalone Training:

```
bash scripts/run_standalone_train_ascend.sh $PRETRAINED_CKPT
```

Results and checkpoints are written to the `./train` folder. The log can be found in `./train/log` and loss values are recorded in `./train/loss.log`.

`$PRETRAINED_CKPT` is the path to a model checkpoint and it is **optional**. If none is given, the model will be trained from scratch.

- Distributed Training:

```
bash scripts/run_distribute_train_ascend.sh $RANK_TABLE_FILE $PRETRAINED_CKPT
```

Results and checkpoints are written to the `./train_parallel_{i}` folder for device `i` respectively.
The log can be found in `./train_parallel_{i}/log_{i}.log` and loss values are recorded in `./train_parallel_{i}/loss.log`.

`$RANK_TABLE_FILE` is needed when running a distributed task on Ascend.
`$PRETRAINED_CKPT` is the path to a model checkpoint and it is **optional**. If none is given, the model will be trained from scratch.

### Training Result

Training results are stored in the example path, in folders whose names begin with "train" or "train_parallel". You can find checkpoint files together with results like the following in loss.log.

```
# distribute training result(8p)
epoch: 1 step: 1 , loss is 76.25, average time per step is 0.335177839748392712
epoch: 1 step: 2 , loss is 73.46875, average time per step is 0.36798572540283203
epoch: 1 step: 3 , loss is 69.46875, average time per step is 0.3429678678512573
epoch: 1 step: 4 , loss is 64.3125, average time per step is 0.33512671788533527
epoch: 1 step: 5 , loss is 58.375, average time per step is 0.33149147033691406
epoch: 1 step: 6 , loss is 52.7265625, average time per step is 0.3292975425720215
...
epoch: 1 step: 8689 , loss is 9.706798802612482, average time per step is 0.3184656601312549
epoch: 1 step: 8690 , loss is 9.70612545289855, average time per step is 0.3184725407765116
epoch: 1 step: 8691 , loss is 9.70695776049204, average time per step is 0.31847309686135555
epoch: 1 step: 8692 , loss is 9.707279624277456, average time per step is 0.31847339290613375
epoch: 1 step: 8693 , loss is 9.70763437950938, average time per step is 0.3184720295013031
epoch: 1 step: 8694 , loss is 9.707695425072046, average time per step is 0.31847410284595573
epoch: 1 step: 8695 , loss is 9.708408273381295, average time per step is 0.31847338271072345
epoch: 1 step: 8696 , loss is 9.708703753591953, average time per step is 0.3184726025560777
epoch: 1 step: 8697 , loss is 9.709536406025824, average time per step is 0.31847212061114694
epoch: 1 step: 8698 , loss is 9.708542263610315, average time per step is 0.3184715309307257
```

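The `loss.log` lines shown above have a fixed shape, so they can be parsed for plotting or monitoring with a small helper. A sketch, assuming exactly the format printed by the loss callback:

```python
import re

LINE_RE = re.compile(
    r"epoch: (\d+) step: (\d+) , loss is ([\d.]+), average time per step is ([\d.]+)")

def parse_loss_line(line):
    """Extract (epoch, step, loss, step_time) from one loss.log line, or None."""
    m = LINE_RE.match(line)
    if m is None:
        return None
    epoch, step, loss, step_time = m.groups()
    return int(epoch), int(step), float(loss), float(step_time)

sample = "epoch: 1 step: 1 , loss is 76.25, average time per step is 0.335177839748392712"
print(parse_loss_line(sample))
```

Applying it over every line of `./train/loss.log` yields a loss curve ready for plotting.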
## [Evaluation Process](#contents)

### Evaluation

- Evaluation:

```
bash scripts/run_eval_ascend.sh $TRAINED_CKPT
```

The model will be evaluated on the IIIT dataset; sample results and the overall accuracy will be printed.

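The decoding step performed during evaluation is, at its core, CTC greedy (best-path) decoding: collapse repeated indices, then drop blanks. The real implementation lives in `CTCLabelConverter` in `src/util.py`; this is a simplified, framework-free sketch in which index 0 is assumed to be the <blank> label:

```python
def ctc_greedy_decode(indices, charset, blank=0):
    """Collapse repeats, drop blanks, and map the remaining indices to characters."""
    out = []
    prev = None
    for idx in indices:
        if idx != prev and idx != blank:
            out.append(charset[idx - 1])  # charset holds the non-blank labels
        prev = idx
    return ''.join(out)

charset = 'abcdefghijklmnopqrstuvwxyz'
# Raw per-frame argmax indices: blank, 'h', 'h', blank, 'i'
print(ctc_greedy_decode([0, 8, 8, 0, 9], charset))  # -> "hi"
```

Note that a blank between two identical labels separates them, which is how CTC can emit doubled letters such as "oo".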
# [Model Description](#contents)

## [Performance](#contents)

### Training Performance

| Parameters                 | CNNCTC                                                       |
| -------------------------- | ------------------------------------------------------------ |
| Model Version              | V1                                                           |
| Resource                   | Ascend 910; CPU 2.60GHz, 192 cores; Memory 755G              |
| Uploaded Date              | 09/28/2020 (month/day/year)                                  |
| MindSpore Version          | 1.0.0                                                        |
| Dataset                    | MJSynth, SynthText                                           |
| Training Parameters        | epoch=3, batch_size=192                                      |
| Optimizer                  | RMSProp                                                      |
| Loss Function              | CTCLoss                                                      |
| Speed                      | 1pc: 300 ms/step; 8pcs: 310 ms/step                          |
| Total time                 | 1pc: 18 hours; 8pcs: 2.3 hours                               |
| Parameters (M)             | 177                                                          |
| Scripts                    | <https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/cnnctc> |

### Evaluation Performance

| Parameters          | CNNCTC                      |
| ------------------- | --------------------------- |
| Model Version       | V1                          |
| Resource            | Ascend 910                  |
| Uploaded Date       | 09/28/2020 (month/day/year) |
| MindSpore Version   | 1.0.0                       |
| Dataset             | IIIT5K                      |
| batch_size          | 192                         |
| outputs             | Accuracy                    |
| Accuracy            | 85%                         |
| Model for inference | 675M (.ckpt file)           |

## [How to use](#contents)

### Inference

If you need to use the trained model to perform inference on multiple hardware platforms, such as GPU, Ascend 910 or Ascend 310, you can refer to this [Link](https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/network_migration.html). Below is a simple example of the steps to follow:

- Running on Ascend

```
# Set context
context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
context.set_context(device_id=cfg.device_id)

# Load unseen dataset for inference
dataset = dataset.create_dataset(cfg.data_path, 1, False)

# Define model
net = CNNCTC(cfg.NUM_CLASS, cfg.HIDDEN_SIZE, cfg.FINAL_FEATURE_WIDTH)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01,
               cfg.momentum, weight_decay=cfg.weight_decay)
loss = P.CTCLoss(preprocess_collapse_repeated=False,
                 ctc_merge_repeated=True,
                 ignore_longer_outputs_than_inputs=False)
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})

# Load pre-trained model
param_dict = load_checkpoint(cfg.checkpoint_path)
load_param_into_net(net, param_dict)
net.set_train(False)

# Make predictions on the unseen dataset
acc = model.eval(dataset)
print("accuracy: ", acc)
```

### Continue Training on the Pretrained Model

- Running on Ascend

```
# Load dataset
dataset = create_dataset(cfg.data_path, 1)
batch_num = dataset.get_dataset_size()

# Define model
net = CNNCTC(cfg.NUM_CLASS, cfg.HIDDEN_SIZE, cfg.FINAL_FEATURE_WIDTH)
# Continue training if pre_trained is set to True
if cfg.pre_trained:
    param_dict = load_checkpoint(cfg.checkpoint_path)
    load_param_into_net(net, param_dict)
lr = lr_steps(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size,
              steps_per_epoch=batch_num)
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()),
               Tensor(lr), cfg.momentum, weight_decay=cfg.weight_decay)
loss = P.CTCLoss(preprocess_collapse_repeated=False,
                 ctc_merge_repeated=True,
                 ignore_longer_outputs_than_inputs=False)
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},
              amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None)

# Set callbacks
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5,
                             keep_checkpoint_max=cfg.keep_checkpoint_max)
time_cb = TimeMonitor(data_size=batch_num)
ckpoint_cb = ModelCheckpoint(prefix="train_googlenet_cifar10", directory="./",
                             config=config_ck)
loss_cb = LossMonitor()

# Start training
model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb])
print("train success")
```

# [ModelZoo Homepage](#contents)

Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).

@ -0,0 +1,109 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""cnnctc eval"""

import argparse
import time
import numpy as np

from mindspore import Tensor, context
import mindspore.common.dtype as mstype
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.dataset import GeneratorDataset

from src.util import CTCLabelConverter, AverageMeter
from src.config import Config_CNNCTC
from src.dataset import IIIT_Generator_batch
from src.cnn_ctc import CNNCTC_Model

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False,
                    save_graphs_path=".", enable_auto_mixed_precision=False)


def test_dataset_creator():
    ds = GeneratorDataset(IIIT_Generator_batch, ['img', 'label_indices', 'text', 'sequence_length', 'label_str'])
    return ds


def test(config):
    ds = test_dataset_creator()

    net = CNNCTC_Model(config.NUM_CLASS, config.HIDDEN_SIZE, config.FINAL_FEATURE_WIDTH)

    ckpt_path = config.CKPT_PATH
    param_dict = load_checkpoint(ckpt_path)
    load_param_into_net(net, param_dict)
    print('parameters loaded! from: ', ckpt_path)

    converter = CTCLabelConverter(config.CHARACTER)

    model_run_time = AverageMeter()
    npu_to_cpu_time = AverageMeter()
    postprocess_time = AverageMeter()

    count = 0
    correct_count = 0
    for data in ds.create_tuple_iterator():
        img, _, text, _, length = data

        img_tensor = Tensor(img, mstype.float32)

        model_run_begin = time.time()
        model_predict = net(img_tensor)
        model_run_end = time.time()
        model_run_time.update(model_run_end - model_run_begin)

        npu_to_cpu_begin = time.time()
        model_predict = np.squeeze(model_predict.asnumpy())
        npu_to_cpu_end = time.time()
        npu_to_cpu_time.update(npu_to_cpu_end - npu_to_cpu_begin)

        postprocess_begin = time.time()
        preds_size = np.array([model_predict.shape[1]] * config.TEST_BATCH_SIZE)
        preds_index = np.argmax(model_predict, 2)
        preds_index = np.reshape(preds_index, [-1])
        preds_str = converter.decode(preds_index, preds_size)
        postprocess_end = time.time()
        postprocess_time.update(postprocess_end - postprocess_begin)

        label_str = converter.reverse_encode(text.asnumpy(), length.asnumpy())

        # the first iteration includes graph compilation, so skip it in the timings
        if count == 0:
            model_run_time.reset()
            npu_to_cpu_time.reset()
            postprocess_time.reset()
        else:
            print('---------model run time--------', model_run_time.avg)
            print('---------npu_to_cpu run time--------', npu_to_cpu_time.avg)
            print('---------postprocess run time--------', postprocess_time.avg)

        print("Prediction samples: \n", preds_str[:5])
        print("Ground truth: \n", label_str[:5])
        for pred, label in zip(preds_str, label_str):
            if pred == label:
                correct_count += 1
            count += 1

    print('accuracy: ', correct_count / count)


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description="CNNCTC evaluation")
    parser.add_argument('--device_id', type=int, default=0, help="Device id, default is 0.")
    parser.add_argument("--ckpt_path", type=str, default="", help="trained file path.")
    args_opt = parser.parse_args()

    cfg = Config_CNNCTC()
    if args_opt.ckpt_path != "":
        cfg.CKPT_PATH = args_opt.ckpt_path
    test(cfg)
@ -0,0 +1,57 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

current_exec_path=$(pwd)
echo ${current_exec_path}

get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

PATH1=$(get_real_path $1)
echo $PATH1

PATH2=$(get_real_path $2)
echo $PATH2

python ${current_exec_path}/src/generate_hccn_file.py --rank_file=$PATH1
export RANK_TABLE_FILE=$PATH1
export RANK_SIZE=8
ulimit -u unlimited
for((i=0;i<$RANK_SIZE;i++));
do
    rm ./train_parallel_$i/ -rf
    mkdir ./train_parallel_$i
    cp ./*.py ./train_parallel_$i
    cp ./scripts/*.sh ./train_parallel_$i
    cp -r ./src ./train_parallel_$i
    cd ./train_parallel_$i || exit
    export RANK_ID=$i
    export DEVICE_ID=$i
    echo "start training for rank $RANK_ID, device $DEVICE_ID"
    if [ -f $PATH2 ]
    then
        python train.py --device_id=$i --ckpt_path=$PATH2 --run_distribute=True >log_$i.log 2>&1 &
    else
        python train.py --device_id=$i --run_distribute=True >log_$i.log 2>&1 &
    fi
    cd .. || exit
done
@ -0,0 +1,54 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

if [ $# -ne 1 ]
then
    echo "Usage: sh run_eval_ascend.sh [TRAINED_CKPT]"
    exit 1
fi

get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}

PATH1=$(get_real_path $1)
echo $PATH1
if [ ! -f $PATH1 ]
then
    echo "error: TRAINED_CKPT=$PATH1 is not a file"
    exit 1
fi

ulimit -u unlimited
export DEVICE_ID=0

if [ -d "eval" ];
then
    rm -rf ./eval
fi
mkdir ./eval
cp ./*.py ./eval
cp ./scripts/*.sh ./eval
cp -r ./src ./eval
cd ./eval || exit
echo "start inferring for device $DEVICE_ID"
env > env.log
python eval.py --device_id=$DEVICE_ID --ckpt_path=$PATH1 &> log &
cd .. || exit
@ -0,0 +1,45 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

get_real_path(){
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        echo "$(realpath -m $PWD/$1)"
    fi
}
PATH1=$(get_real_path $1)

ulimit -u unlimited

if [ -d "train" ];
then
    rm -rf ./train
fi
mkdir ./train
cp ./*.py ./train
cp ./scripts/*.sh ./train
cp -r ./src ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env > env.log
if [ -f $PATH1 ]
then
    python train.py --device_id=$DEVICE_ID --ckpt_path=$PATH1 --run_distribute=False &> log &
else
    python train.py --device_id=$DEVICE_ID --run_distribute=False &> log &
fi
cd .. || exit
@ -0,0 +1,15 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""src init file"""
@ -0,0 +1,71 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""loss callback"""

import time
from mindspore.train.callback import Callback
from .util import AverageMeter


class LossCallBack(Callback):
    """
    Monitor the loss in training.

    If the loss is NAN or INF, terminate training.

    Note:
        If per_print_times is 0, do not print the loss.

    Args:
        per_print_times (int): Print the loss every per_print_times steps. Default: 1.
    """

    def __init__(self, per_print_times=1):
        super(LossCallBack, self).__init__()
        if not isinstance(per_print_times, int) or per_print_times < 0:
            raise ValueError("print_step must be int and >= 0.")
        self._per_print_times = per_print_times
        self.loss_avg = AverageMeter()
        self.timer = AverageMeter()
        self.start_time = time.time()

    def step_end(self, run_context):
        cb_params = run_context.original_args()

        loss = cb_params.net_outputs.asnumpy()

        cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1
        cur_num = cb_params.cur_step_num

        # reset the running averages every 2000 steps within an epoch
        if cur_step_in_epoch % 2000 == 1:
            self.loss_avg = AverageMeter()
            self.timer = AverageMeter()
            self.start_time = time.time()
        else:
            self.timer.update(time.time() - self.start_time)
            self.start_time = time.time()

        self.loss_avg.update(loss)

        if self._per_print_times != 0 and cur_num % self._per_print_times == 0:
            loss_file = open("./loss.log", "a+")
            loss_file.write("epoch: %s step: %s , loss is %s, average time per step is %s" % (
                cb_params.cur_epoch_num, cur_step_in_epoch,
                self.loss_avg.avg, self.timer.avg))
            loss_file.write("\n")
            loss_file.close()

            print("epoch: %s step: %s , loss is %s, average time per step is %s" % (
                cb_params.cur_epoch_num, cur_step_in_epoch,
                self.loss_avg.avg, self.timer.avg))
@ -0,0 +1,255 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""cnn_ctc network define"""

import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.common.initializer import TruncatedNormal, initializer
import mindspore.common.dtype as mstype


class CNNCTC_Model(nn.Cell):

    def __init__(self, num_class, hidden_size, final_feature_width):
        super(CNNCTC_Model, self).__init__()

        self.num_class = num_class
        self.hidden_size = hidden_size
        self.final_feature_width = final_feature_width

        self.FeatureExtraction = ResNet_FeatureExtractor()
        self.Prediction = nn.Dense(self.hidden_size, self.num_class)

        self.transpose = P.Transpose()
        self.reshape = P.Reshape()

    def construct(self, x):
        x = self.FeatureExtraction(x)
        x = self.transpose(x, (0, 3, 1, 2))  # [b, c, h, w] -> [b, w, c, h]

        x = self.reshape(x, (-1, self.hidden_size))
        x = self.Prediction(x)
        x = self.reshape(x, (-1, self.final_feature_width, self.num_class))

        return x


class WithLossCell(nn.Cell):

    def __init__(self, backbone, loss_fn):
        super(WithLossCell, self).__init__(auto_prefix=False)
        self._backbone = backbone
        self._loss_fn = loss_fn

    def construct(self, img, label_indices, text, sequence_length):
        model_predict = self._backbone(img)
        return self._loss_fn(model_predict, label_indices, text, sequence_length)

    @property
    def backbone_network(self):
        return self._backbone


class ctc_loss(nn.Cell):

    def __init__(self):
        super(ctc_loss, self).__init__()

        self.loss = P.CTCLoss(preprocess_collapse_repeated=False,
                              ctc_merge_repeated=True,
                              ignore_longer_outputs_than_inputs=False)

        self.mean = P.ReduceMean()
        self.transpose = P.Transpose()
        self.reshape = P.Reshape()

    def construct(self, inputs, labels_indices, labels_values, sequence_length):
        # CTCLoss expects time-major inputs: [w, b, num_class]
        inputs = self.transpose(inputs, (1, 0, 2))

        loss, _ = self.loss(inputs, labels_indices, labels_values, sequence_length)

        loss = self.mean(loss)
        return loss


class ResNet_FeatureExtractor(nn.Cell):
    def __init__(self):
        super(ResNet_FeatureExtractor, self).__init__()
        self.ConvNet = ResNet(3, 512, BasicBlock, [1, 2, 5, 3])

    def construct(self, featuremap):
        return self.ConvNet(featuremap)


class ResNet(nn.Cell):
    def __init__(self, input_channel, output_channel, block, layers):
        super(ResNet, self).__init__()

        self.output_channel_block = [int(output_channel / 4), int(output_channel / 2), output_channel, output_channel]

        self.inplanes = int(output_channel / 8)
        self.conv0_1 = ms_conv3x3(input_channel, int(output_channel / 16), stride=1, padding=1, pad_mode='pad')
        self.bn0_1 = ms_fused_bn(int(output_channel / 16))
        self.conv0_2 = ms_conv3x3(int(output_channel / 16), self.inplanes, stride=1, padding=1, pad_mode='pad')
        self.bn0_2 = ms_fused_bn(self.inplanes)
        self.relu = P.ReLU()

        self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='valid')
        self.layer1 = self._make_layer(block, self.output_channel_block[0], layers[0])
        self.conv1 = ms_conv3x3(self.output_channel_block[0], self.output_channel_block[0], stride=1, padding=1,
                                pad_mode='pad')
        self.bn1 = ms_fused_bn(self.output_channel_block[0])

        self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='valid')
        self.layer2 = self._make_layer(block, self.output_channel_block[1], layers[1])
        self.conv2 = ms_conv3x3(self.output_channel_block[1], self.output_channel_block[1], stride=1, padding=1,
                                pad_mode='pad')
        self.bn2 = ms_fused_bn(self.output_channel_block[1])

        self.pad = P.Pad(((0, 0), (0, 0), (0, 0), (1, 1)))
        self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=(2, 1), pad_mode='valid')
        self.layer3 = self._make_layer(block, self.output_channel_block[2], layers[2])
        self.conv3 = ms_conv3x3(self.output_channel_block[2], self.output_channel_block[2], stride=1, padding=1,
                                pad_mode='pad')
        self.bn3 = ms_fused_bn(self.output_channel_block[2])

        self.layer4 = self._make_layer(block, self.output_channel_block[3], layers[3])
        self.conv4_1 = ms_conv2x2(self.output_channel_block[3], self.output_channel_block[3], stride=(2, 1),
                                  pad_mode='valid')
        self.bn4_1 = ms_fused_bn(self.output_channel_block[3])

        self.conv4_2 = ms_conv2x2(self.output_channel_block[3], self.output_channel_block[3], stride=1, padding=0,
                                  pad_mode='valid')
        self.bn4_2 = ms_fused_bn(self.output_channel_block[3])

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.SequentialCell(
                [ms_conv1x1(self.inplanes, planes * block.expansion, stride=stride),
                 ms_fused_bn(planes * block.expansion)]
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return nn.SequentialCell(layers)

    def construct(self, x):
        x = self.conv0_1(x)
        x = self.bn0_1(x)
        x = self.relu(x)
        x = self.conv0_2(x)
        x = self.bn0_2(x)
        x = self.relu(x)

        x = self.maxpool1(x)
        x = self.layer1(x)
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.maxpool2(x)
        x = self.layer2(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        x = self.pad(x)
        x = self.maxpool3(x)
        x = self.layer3(x)
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu(x)

        x = self.layer4(x)
        x = self.pad(x)
        x = self.conv4_1(x)
        x = self.bn4_1(x)
        x = self.relu(x)
        x = self.conv4_2(x)
        x = self.bn4_2(x)
        x = self.relu(x)

        return x


class BasicBlock(nn.Cell):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(BasicBlock, self).__init__()

        self.conv1 = ms_conv3x3(inplanes, planes, stride=stride, padding=1, pad_mode='pad')
        self.bn1 = ms_fused_bn(planes)
        # the second conv of a basic block always has stride 1
        self.conv2 = ms_conv3x3(planes, planes, stride=1, padding=1, pad_mode='pad')
        self.bn2 = ms_fused_bn(planes)
        self.relu = P.ReLU()
        self.downsample = downsample
        self.add = P.TensorAdd()

    def construct(self, x):
        residual = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            residual = self.downsample(x)
        out = self.add(out, residual)
        out = self.relu(out)

        return out


def weight_variable(shape, factor=0.1, half_precision=False):
    if half_precision:
        return initializer(TruncatedNormal(0.02), shape, dtype=mstype.float16)

    return TruncatedNormal(0.02)


def ms_conv3x3(in_channels, out_channels, stride=1, padding=0, pad_mode='same', has_bias=False):
    """Get a conv2d layer with 3x3 kernel size."""
    init_value = weight_variable((out_channels, in_channels, 3, 3))
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=3, stride=stride, padding=padding, pad_mode=pad_mode, weight_init=init_value,
                     has_bias=has_bias)


def ms_conv1x1(in_channels, out_channels, stride=1, padding=0, pad_mode='same', has_bias=False):
    """Get a conv2d layer with 1x1 kernel size."""
    init_value = weight_variable((out_channels, in_channels, 1, 1))
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=1, stride=stride, padding=padding, pad_mode=pad_mode, weight_init=init_value,
                     has_bias=has_bias)


def ms_conv2x2(in_channels, out_channels, stride=1, padding=0, pad_mode='same', has_bias=False):
    """Get a conv2d layer with 2x2 kernel size."""
    init_value = weight_variable((out_channels, in_channels, 2, 2))
    return nn.Conv2d(in_channels, out_channels,
                     kernel_size=2, stride=stride, padding=padding, pad_mode=pad_mode, weight_init=init_value,
                     has_bias=has_bias)


def ms_fused_bn(channels, momentum=0.1):
    """Get a fused batchnorm"""
    return nn.BatchNorm2d(channels, momentum=momentum)
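As a sanity check on the architecture above, the spatial dimensions of a 32x100 input can be traced through the pools and strided convolutions of `ResNet_FeatureExtractor` by hand. The sketch below is plain Python (no MindSpore), assuming the 32x100 input size used by this model's config; it shows why the sequence width fed to CTC is 26.

```python
# Trace (h, w) through the feature extractor for a 32x100 input.
def conv_out(size, kernel, stride):
    """Output size of a 'valid' conv/pool (padding applied beforehand)."""
    return (size - kernel) // stride + 1

h, w = 32, 100
h, w = conv_out(h, 2, 2), conv_out(w, 2, 2)   # maxpool1 -> 16 x 50
h, w = conv_out(h, 2, 2), conv_out(w, 2, 2)   # maxpool2 -> 8 x 25
w += 2                                        # Pad (1, 1) on width -> 8 x 27
h, w = conv_out(h, 2, 2), conv_out(w, 2, 1)   # maxpool3 stride (2, 1) -> 4 x 26
w += 2                                        # Pad (1, 1) on width -> 4 x 28
h, w = conv_out(h, 2, 2), conv_out(w, 2, 1)   # conv4_1 stride (2, 1) -> 2 x 27
h, w = conv_out(h, 2, 1), conv_out(w, 2, 1)   # conv4_2 stride 1 -> 1 x 26

print(h, w)  # 1 26
```

The height collapses to 1 while the width ends at 26, matching `FINAL_FEATURE_WIDTH = 26` in the config: each of the 26 width positions becomes one 512-dimensional time step for CTC.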
@ -0,0 +1,43 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""network config setting, will be used in train.py and eval.py"""


class Config_CNNCTC():
    # model config
    CHARACTER = '0123456789abcdefghijklmnopqrstuvwxyz'
    NUM_CLASS = len(CHARACTER) + 1
    HIDDEN_SIZE = 512
    FINAL_FEATURE_WIDTH = 26

    # dataset config
    IMG_H = 32
    IMG_W = 100
    TRAIN_DATASET_PATH = 'CNNCTC_Data/ST_MJ/'
    TRAIN_DATASET_INDEX_PATH = 'CNNCTC_Data/st_mj_fixed_length_index_list.pkl'
    TRAIN_BATCH_SIZE = 192
    TEST_DATASET_PATH = 'CNNCTC_Data/IIIT5k_3000'
    TEST_BATCH_SIZE = 256
    TEST_DATASET_SIZE = 2976
    TRAIN_EPOCHS = 3

    # training config
    CKPT_PATH = ''
    SAVE_PATH = './'
    LR = 1e-4
    LR_PARA = 5e-4
    MOMENTUM = 0.8
    LOSS_SCALE = 8096
    SAVE_CKPT_PER_N_STEP = 2000
    KEEP_CKPT_MAX_NUM = 5
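The `+ 1` in `NUM_CLASS` is the extra class that CTC reserves for the "blank" symbol, on top of the 36 recognizable characters. A one-line check:

```python
# 10 digits + 26 lowercase letters + 1 CTC blank = 37 output classes.
CHARACTER = '0123456789abcdefghijklmnopqrstuvwxyz'
NUM_CLASS = len(CHARACTER) + 1
print(NUM_CLASS)  # 37
```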
@ -0,0 +1,265 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""cnn_ctc dataset"""

import sys
import pickle
import math
import six
import numpy as np
from PIL import Image
import lmdb

from mindspore.communication.management import get_rank, get_group_size

from .util import CTCLabelConverter
from .config import Config_CNNCTC

config = Config_CNNCTC()


class NormalizePAD():

    def __init__(self, max_size, PAD_type='right'):
        self.max_size = max_size
        self.PAD_type = PAD_type

    def __call__(self, img):
        # to tensor: HWC uint8 -> CHW float32 in [0, 1]
        img = np.array(img, dtype=np.float32)
        img = img.transpose([2, 0, 1])
        img = np.true_divide(img, 255)
        # normalize to [-1, 1]
        img = np.subtract(img, 0.5)
        img = np.true_divide(img, 0.5)

        _, _, w = img.shape
        Pad_img = np.zeros(shape=self.max_size, dtype=np.float32)
        Pad_img[:, :, :w] = img  # right pad
        if self.max_size[2] != w:  # add border pad
            Pad_img[:, :, w:] = np.tile(np.expand_dims(img[:, :, w - 1], 2), (1, 1, self.max_size[2] - w))

        return Pad_img


class AlignCollate():

    def __init__(self, imgH=32, imgW=100):
        self.imgH = imgH
        self.imgW = imgW

    def __call__(self, images):

        resized_max_w = self.imgW
        input_channel = 3
        transform = NormalizePAD((input_channel, self.imgH, resized_max_w))

        resized_images = []
        for image in images:
            w, h = image.size
            ratio = w / float(h)
            if math.ceil(self.imgH * ratio) > self.imgW:
                resized_w = self.imgW
            else:
                resized_w = math.ceil(self.imgH * ratio)

            resized_image = image.resize((resized_w, self.imgH), Image.BICUBIC)
            resized_images.append(transform(resized_image))

        image_tensors = np.concatenate([np.expand_dims(t, 0) for t in resized_images], 0)

        return image_tensors


def get_img_from_lmdb(env, index):
    with env.begin(write=False) as txn:
        label_key = 'label-%09d'.encode() % index
        label = txn.get(label_key).decode('utf-8')
        img_key = 'image-%09d'.encode() % index
        imgbuf = txn.get(img_key)

        buf = six.BytesIO()
        buf.write(imgbuf)
        buf.seek(0)
        try:
            img = Image.open(buf).convert('RGB')  # for color image

        except IOError:
            print(f'Corrupted image for {index}')
            # make a dummy image and a dummy label for the corrupted image
            img = Image.new('RGB', (config.IMG_W, config.IMG_H))
            label = '[dummy_label]'

    label = label.lower()

    return img, label


class ST_MJ_Generator_batch_fixed_length:
    def __init__(self):
        self.align_collector = AlignCollate()
        self.converter = CTCLabelConverter(config.CHARACTER)
        self.env = lmdb.open(config.TRAIN_DATASET_PATH, max_readers=32, readonly=True, lock=False, readahead=False,
                             meminit=False)
        if not self.env:
            print('cannot open lmdb from %s' % (config.TRAIN_DATASET_PATH))
            raise ValueError(config.TRAIN_DATASET_PATH)

        with open(config.TRAIN_DATASET_INDEX_PATH, 'rb') as f:
            self.st_mj_filtered_index_list = pickle.load(f)

        print(f'num of samples in ST_MJ dataset: {len(self.st_mj_filtered_index_list)}')
        self.dataset_size = len(self.st_mj_filtered_index_list) // config.TRAIN_BATCH_SIZE
        self.batch_size = config.TRAIN_BATCH_SIZE

    def __len__(self):
        return self.dataset_size

    def __getitem__(self, item):
        img_ret = []
        text_ret = []

        for i in range(item * self.batch_size, (item + 1) * self.batch_size):
            index = self.st_mj_filtered_index_list[i]
            img, label = get_img_from_lmdb(self.env, index)

            img_ret.append(img)
            text_ret.append(label)

        img_ret = self.align_collector(img_ret)
        text_ret, length = self.converter.encode(text_ret)

        label_indices = []
        for i, _ in enumerate(length):
            for j in range(length[i]):
                label_indices.append((i, j))
        label_indices = np.array(label_indices, np.int64)
        sequence_length = np.array([config.FINAL_FEATURE_WIDTH] * config.TRAIN_BATCH_SIZE, dtype=np.int32)
        text_ret = text_ret.astype(np.int32)

        return img_ret, label_indices, text_ret, sequence_length


class ST_MJ_Generator_batch_fixed_length_para:
    def __init__(self):
        self.align_collector = AlignCollate()
        self.converter = CTCLabelConverter(config.CHARACTER)
        self.env = lmdb.open(config.TRAIN_DATASET_PATH, max_readers=32, readonly=True, lock=False, readahead=False,
                             meminit=False)
        if not self.env:
            print('cannot open lmdb from %s' % (config.TRAIN_DATASET_PATH))
            raise ValueError(config.TRAIN_DATASET_PATH)

        with open(config.TRAIN_DATASET_INDEX_PATH, 'rb') as f:
            self.st_mj_filtered_index_list = pickle.load(f)

        print(f'num of samples in ST_MJ dataset: {len(self.st_mj_filtered_index_list)}')
        self.rank_id = get_rank()
        self.rank_size = get_group_size()
        self.dataset_size = len(self.st_mj_filtered_index_list) // config.TRAIN_BATCH_SIZE // self.rank_size
        self.batch_size = config.TRAIN_BATCH_SIZE

    def __len__(self):
        return self.dataset_size

    def __getitem__(self, item):
        img_ret = []
        text_ret = []

        rank_item = (item * self.rank_size) + self.rank_id
        for i in range(rank_item * self.batch_size, (rank_item + 1) * self.batch_size):
            index = self.st_mj_filtered_index_list[i]
            img, label = get_img_from_lmdb(self.env, index)

            img_ret.append(img)
            text_ret.append(label)

        img_ret = self.align_collector(img_ret)
        text_ret, length = self.converter.encode(text_ret)

        label_indices = []
        for i, _ in enumerate(length):
            for j in range(length[i]):
                label_indices.append((i, j))
        label_indices = np.array(label_indices, np.int64)
        sequence_length = np.array([config.FINAL_FEATURE_WIDTH] * config.TRAIN_BATCH_SIZE, dtype=np.int32)
        text_ret = text_ret.astype(np.int32)

        return img_ret, label_indices, text_ret, sequence_length


def IIIT_Generator_batch():
    max_len = int((26 + 1) // 2)

    align_collector = AlignCollate()

    converter = CTCLabelConverter(config.CHARACTER)

    env = lmdb.open(config.TEST_DATASET_PATH, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False)
    if not env:
        print('cannot open lmdb from %s' % (config.TEST_DATASET_PATH))
        sys.exit(0)

    with env.begin(write=False) as txn:
        nSamples = int(txn.get('num-samples'.encode()))

        # Filtering
        filtered_index_list = []
        for index in range(nSamples):
            index += 1  # lmdb starts with 1
            label_key = 'label-%09d'.encode() % index
            label = txn.get(label_key).decode('utf-8')

            if len(label) > max_len:
                continue

            illegal_sample = False
            for char_item in label.lower():
                if char_item not in config.CHARACTER:
                    illegal_sample = True
                    break
            if illegal_sample:
                continue

            filtered_index_list.append(index)

        img_ret = []
        text_ret = []

        print(f'num of samples in IIIT dataset: {len(filtered_index_list)}')

        for index in filtered_index_list:

            img, label = get_img_from_lmdb(env, index)

            img_ret.append(img)
            text_ret.append(label)

            if len(img_ret) == config.TEST_BATCH_SIZE:
                img_ret = align_collector(img_ret)
                text_ret, length = converter.encode(text_ret)

                label_indices = []
                for i, _ in enumerate(length):
                    for j in range(length[i]):
                        label_indices.append((i, j))
                label_indices = np.array(label_indices, np.int64)
                sequence_length = np.array([26] * config.TEST_BATCH_SIZE, dtype=np.int32)
                text_ret = text_ret.astype(np.int32)

                yield img_ret, label_indices, text_ret, sequence_length, length

                img_ret = []
                text_ret = []
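The `label_indices` loop that appears in all three generators builds the sparse label representation CTCLoss expects: one `(sample, position)` pair per character, paired with a flattened list of character indices. A minimal plain-Python sketch of that construction (`char_to_idx` stands in for CTCLabelConverter's internal mapping, which is an assumption here; reserving index 0 for the CTC blank is one common convention):

```python
# Hypothetical character-to-index map: 0 is reserved for the CTC blank.
CHARACTER = '0123456789abcdefghijklmnopqrstuvwxyz'
char_to_idx = {c: i + 1 for i, c in enumerate(CHARACTER)}

labels = ['hi', 'abc']
values = [char_to_idx[c] for text in labels for c in text]   # flattened label values
lengths = [len(text) for text in labels]

# Same double loop as in the generators above.
label_indices = []
for i, n in enumerate(lengths):
    for j in range(n):
        label_indices.append((i, j))  # (sample in batch, position in label)

print(values)         # [18, 19, 11, 12, 13]
print(label_indices)  # [(0, 0), (0, 1), (1, 0), (1, 1), (1, 2)]
```

Together, `label_indices` and `values` describe a sparse `[batch, max_label_len]` matrix, which is why every batch can carry variable-length labels while the dense image tensor stays fixed-shape.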
@ -0,0 +1,88 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""generate ascend rank file"""

import os
import json
import socket
import argparse

parser = argparse.ArgumentParser(description="ascend distribute rank.")
parser.add_argument("--rank_file", type=str, default="scripts/rank_table_8p.json", help="rank_table_file_path.")


def main(rank_table_file):
    nproc_per_node = 8

    visible_devices = ['0', '1', '2', '3', '4', '5', '6', '7']

    server_id = socket.gethostbyname(socket.gethostname())

    with open('/etc/hccn.conf', 'r') as fp:
        hccn_configs = fp.readlines()
    device_ips = {}
    for hccn_item in hccn_configs:
        hccn_item = hccn_item.strip()
        if hccn_item.startswith('address_'):
            device_id, device_ip = hccn_item.split('=')
            device_id = device_id.split('_')[1]
            device_ips[device_id] = device_ip
            print('device_id:{}, device_ip:{}'.format(device_id, device_ip))

    hccn_table = {}
    hccn_table['board_id'] = '0x002f'  # A+K
    # hccn_table['board_id'] = '0x0000'  # A+X

    hccn_table['chip_info'] = '910'
    hccn_table['deploy_mode'] = 'lab'
    hccn_table['group_count'] = '1'
    hccn_table['group_list'] = []
    instance_list = []
    for instance_id in range(nproc_per_node):
        instance = {}
        instance['devices'] = []
        device_id = visible_devices[instance_id]
        device_ip = device_ips[device_id]
        instance['devices'].append({
            'device_id': device_id,
            'device_ip': device_ip,
        })
        instance['rank_id'] = str(instance_id)
        instance['server_id'] = server_id
        instance_list.append(instance)
    hccn_table['group_list'].append({
        'device_num': str(nproc_per_node),
        'server_num': '1',
        'group_name': '',
        'instance_count': str(nproc_per_node),
        'instance_list': instance_list,
    })
    hccn_table['para_plane_nic_location'] = 'device'
    hccn_table['para_plane_nic_name'] = []
    for instance_id in range(nproc_per_node):
        eth_id = visible_devices[instance_id]
        hccn_table['para_plane_nic_name'].append('eth{}'.format(eth_id))
    hccn_table['para_plane_nic_num'] = str(nproc_per_node)
    hccn_table['status'] = 'completed'
    with open(rank_table_file, 'w') as table_fp:
        json.dump(hccn_table, table_fp, indent=4)


if __name__ == '__main__':
    args_opt = parser.parse_args()
    rank_table = args_opt.rank_file
    if os.path.exists(rank_table):
        print('Rank table file exists.')
    else:
        print('Generating rank table file.')
        main(rank_table)
        print('Rank table file generated')
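The script above keys everything off the `address_<device_id>=<ip>` lines in `/etc/hccn.conf`. A standalone sketch of just that parsing step, using hypothetical sample lines (not from a real device), makes the expected file format explicit:

```python
# Hypothetical /etc/hccn.conf content; only address_* lines carry device IPs.
sample = """address_0=192.168.100.101
address_1=192.168.100.102
netmask_0=255.255.255.0
"""

device_ips = {}
for line in sample.splitlines():
    line = line.strip()
    if line.startswith('address_'):          # skip netmask_* and other keys
        device_id, device_ip = line.split('=')
        device_id = device_id.split('_')[1]  # 'address_0' -> '0'
        device_ips[device_id] = device_ip

print(device_ips)  # {'0': '192.168.100.101', '1': '192.168.100.102'}
```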
@ -0,0 +1,171 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""preprocess dataset"""

import random
import pickle
import numpy as np
import lmdb
from tqdm import tqdm


def combine_lmdbs(lmdb_paths, lmdb_save_path):
    max_len = int((26 + 1) // 2)
    character = '0123456789abcdefghijklmnopqrstuvwxyz'

    env_save = lmdb.open(
        lmdb_save_path,
        map_size=1099511627776)

    cnt = 0
    for lmdb_path in lmdb_paths:
        env = lmdb.open(lmdb_path, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False)
        with env.begin(write=False) as txn:
            nSamples = int(txn.get('num-samples'.encode()))

            # Filtering
            for index in tqdm(range(nSamples)):
                index += 1  # lmdb starts with 1
                label_key = 'label-%09d'.encode() % index
                label = txn.get(label_key).decode('utf-8')

                if len(label) > max_len:
                    continue

                illegal_sample = False
                for char_item in label.lower():
                    if char_item not in character:
                        illegal_sample = True
                        break
                if illegal_sample:
                    continue

                img_key = 'image-%09d'.encode() % index
                imgbuf = txn.get(img_key)

                with env_save.begin(write=True) as txn_save:
                    cnt += 1

                    label_key_save = 'label-%09d'.encode() % cnt
                    label_save = label.encode()
                    image_key_save = 'image-%09d'.encode() % cnt
                    image_save = imgbuf

                    txn_save.put(label_key_save, label_save)
                    txn_save.put(image_key_save, image_save)

    nSamples = cnt
    with env_save.begin(write=True) as txn_save:
        txn_save.put('num-samples'.encode(), str(nSamples).encode())


def analyze_lmdb_label_length(lmdb_path, batch_size=192, num_of_combinations=1000):
    label_length_dict = {}

    env = lmdb.open(lmdb_path, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False)
    with env.begin(write=False) as txn:
        nSamples = int(txn.get('num-samples'.encode()))

        for index in tqdm(range(nSamples)):
            index += 1  # lmdb starts with 1
            label_key = 'label-%09d'.encode() % index
            label = txn.get(label_key).decode('utf-8')

            label_length = len(label)
            if label_length in label_length_dict:
                label_length_dict[label_length] += 1
            else:
                label_length_dict[label_length] = 1

    sorted_label_length = sorted(label_length_dict.items(), key=lambda x: x[1], reverse=True)

    label_length_sum = 0
    label_num = 0
    lengths = []
    p = []
    for l, num in sorted_label_length:
        label_length_sum += l * num
        label_num += num
        p.append(num)
        lengths.append(l)
    for i, _ in enumerate(p):
        p[i] /= label_num

    average_overall_length = int(label_length_sum / label_num * batch_size)

    def get_combinations_of_fix_length(fix_length, items, p, batch_size):
        # sample batch_size - 1 lengths, then force the last one to hit the target sum
        ret = np.random.choice(items, batch_size - 1, True, p)
        cur_sum = sum(ret)
        ret = list(ret)
        if fix_length - cur_sum in items:
            ret.append(fix_length - cur_sum)
        else:
            return None
        return ret

    result = []
    while len(result) < num_of_combinations:
        ret = get_combinations_of_fix_length(average_overall_length, lengths, p, batch_size)
        if ret is not None:
            result.append(ret)
    return result
||||||
|
|
||||||
|
|
||||||
|
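The helper above keeps every batch's total label length fixed: it draws `batch_size - 1` lengths from the empirical length distribution and then tries to close the gap to the target sum with one final length, retrying on failure. A minimal, self-contained sketch of that retry logic (names and the `numpy.random.default_rng` API choice are illustrative, not from the source):

```python
import numpy as np

def sample_fixed_total(target_sum, lengths, probs, batch_size, max_tries=1000, seed=0):
    """Sample batch_size label lengths whose total is exactly target_sum."""
    rng = np.random.default_rng(seed)
    for _ in range(max_tries):
        # draw all but one length from the empirical distribution
        draw = list(rng.choice(lengths, batch_size - 1, replace=True, p=probs))
        remainder = target_sum - sum(draw)
        # the last slot must be a valid label length to close the sum exactly
        if remainder in lengths:
            draw.append(remainder)
            return draw
    return None  # no valid combination found within max_tries

combo = sample_fixed_total(target_sum=40, lengths=[3, 4, 5], probs=[0.2, 0.5, 0.3], batch_size=10)
```

With a tight length range like this toy example, the retry loop almost always succeeds within a few attempts.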
def generate_fix_shape_index_list(lmdb_path, combinations, pkl_save_path, num_of_iters=70000):
    length_index_dict = {}

    env = lmdb.open(lmdb_path, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False)
    with env.begin(write=False) as txn:
        nSamples = int(txn.get('num-samples'.encode()))

        for index in tqdm(range(nSamples)):
            index += 1  # lmdb starts with 1
            label_key = 'label-%09d'.encode() % index
            label = txn.get(label_key).decode('utf-8')

            label_length = len(label)
            if label_length in length_index_dict:
                length_index_dict[label_length].append(index)
            else:
                length_index_dict[label_length] = [index]

    ret = []
    for _ in range(num_of_iters):
        comb = random.choice(combinations)
        for l in comb:
            ret.append(random.choice(length_index_dict[l]))

    with open(pkl_save_path, 'wb') as f:
        pickle.dump(ret, f, -1)

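The sampled index order is persisted with `pickle.dump(..., -1)`, where `-1` selects the highest pickle protocol available. A quick round-trip sketch of that persistence step, using an in-memory buffer and a toy index list rather than the real lmdb-derived one:

```python
import io
import pickle

index_list = [12, 7, 93, 7, 41]  # toy stand-in for the sampled lmdb indices

buf = io.BytesIO()
pickle.dump(index_list, buf, -1)  # -1 = highest available protocol, as in the source
buf.seek(0)
restored = pickle.load(buf)
assert restored == index_list  # lossless round-trip
```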
if __name__ == '__main__':
    # step 1: combine the SynthText dataset and MJSynth dataset into a single lmdb file
    print('Begin to combine multiple lmdb datasets')
    combine_lmdbs(['/home/workspace/mindspore_dataset/CNNCTC_Data/1_ST/',
                   '/home/workspace/mindspore_dataset/CNNCTC_Data/MJ_train/'],
                  '/home/workspace/mindspore_dataset/CNNCTC_Data/ST_MJ')

    # step 2: generate the order of input data, guarantee that the input batch shape is fixed
    print('Begin to generate the index order of input data')
    combination = analyze_lmdb_label_length('/home/workspace/mindspore_dataset/CNNCTC_Data/ST_MJ')
    generate_fix_shape_index_list('/home/workspace/mindspore_dataset/CNNCTC_Data/ST_MJ', combination,
                                  '/home/workspace/mindspore_dataset/CNNCTC_Data/st_mj_fixed_length_index_list.pkl')

    print('Done')
@@ -0,0 +1,102 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""util file"""

import numpy as np

class AverageMeter():
    """Computes and stores the average and current value"""

    def __init__(self):
        self.reset()

    def reset(self):
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

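`AverageMeter.update(val, n)` folds `n` samples with value `val` into a running mean, so the meter can average per-batch losses weighted by batch size. A compressed restatement of the same arithmetic (illustrative class name, not the source class):

```python
class Meter:
    """Minimal restatement of the AverageMeter running-mean logic."""
    def __init__(self):
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.sum += val * n    # n samples, each contributing val
        self.count += n
        self.avg = self.sum / self.count

m = Meter()
m.update(2.0, n=3)  # e.g. a batch of 3 with mean loss 2.0
m.update(4.0, n=1)  # a batch of 1 with loss 4.0
# m.avg == (2.0 * 3 + 4.0 * 1) / 4 == 2.5
```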
class CTCLabelConverter():
    """ Convert between text-label and text-index """

    def __init__(self, character):
        # character (str): set of the possible characters.
        dict_character = list(character)

        self.dict = {}
        for i, char in enumerate(dict_character):
            self.dict[char] = i

        self.character = dict_character + ['[blank]']  # dummy '[blank]' token for CTCLoss (last index)
        self.dict['[blank]'] = len(dict_character)

    def encode(self, text):
        """convert text-label into text-index.
        input:
            text: text labels of each image. [batch_size]

        output:
            text: concatenated text index for CTCLoss.
                [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)]
            length: length of each text. [batch_size]
        """
        length = [len(s) for s in text]
        text = ''.join(text)
        text = [self.dict[char] for char in text]

        return np.array(text), np.array(length)

    def decode(self, text_index, length):
        """ convert text-index into text-label. """
        texts = []
        index = 0
        for l in length:
            t = text_index[index:index + l]

            char_list = []
            for i in range(l):
                if t[i] != self.dict['[blank]'] and (
                        not (i > 0 and t[i - 1] == t[i])):  # removing repeated characters and blank
                    char_list.append(self.character[t[i]])
            text = ''.join(char_list)

            texts.append(text)
            index += l
        return texts

    def reverse_encode(self, text_index, length):
        """ convert text-index into text-label, removing only blanks. """
        texts = []
        index = 0
        for l in length:
            t = text_index[index:index + l]

            char_list = []
            for i in range(l):
                if t[i] != self.dict['[blank]']:  # removing blank tokens, keeping repeated characters
                    char_list.append(self.character[t[i]])
            text = ''.join(char_list)

            texts.append(text)
            index += l
        return texts

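The core of `decode` is the standard greedy CTC collapse: drop blank tokens and merge consecutive repeats, so a per-timestep prediction like `a a - b - b c c` (with `-` the blank) becomes `abbc`. A self-contained sketch of that rule for a single sequence (illustrative function and variable names):

```python
def ctc_collapse(indices, characters, blank):
    """Greedy CTC collapse: drop blanks and consecutive repeats."""
    out = []
    for i, idx in enumerate(indices):
        # keep a symbol only if it is not blank and differs from its predecessor
        if idx != blank and not (i > 0 and indices[i - 1] == idx):
            out.append(characters[idx])
    return ''.join(out)

chars = ['a', 'b', 'c', '[blank]']  # blank is the last class, matching __init__ above
blank = 3
decoded = ctc_collapse([0, 0, 3, 1, 3, 1, 2, 2], chars, blank)  # 'abbc'
```

Note that the two `b` predictions survive because a blank separates them; that separation is exactly how CTC represents doubled letters.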
@@ -0,0 +1,100 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""cnnctc train"""

import argparse
import ast

import mindspore
from mindspore import context
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.dataset import GeneratorDataset
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.model import Model
from mindspore.communication.management import init
from mindspore.common import set_seed

from src.config import Config_CNNCTC
from src.callback import LossCallBack
from src.dataset import ST_MJ_Generator_batch_fixed_length, ST_MJ_Generator_batch_fixed_length_para
from src.cnn_ctc import CNNCTC_Model, ctc_loss, WithLossCell

set_seed(1)

context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False,
                    save_graphs_path=".", enable_auto_mixed_precision=False)

def dataset_creator(run_distribute):
    if run_distribute:
        st_dataset = ST_MJ_Generator_batch_fixed_length_para()
    else:
        st_dataset = ST_MJ_Generator_batch_fixed_length()

    ds = GeneratorDataset(st_dataset,
                          ['img', 'label_indices', 'text', 'sequence_length'],
                          num_parallel_workers=8)

    return ds

def train(args_opt, config):
    if args_opt.run_distribute:
        init()
        context.set_auto_parallel_context(parallel_mode="data_parallel")

    ds = dataset_creator(args_opt.run_distribute)

    net = CNNCTC_Model(config.NUM_CLASS, config.HIDDEN_SIZE, config.FINAL_FEATURE_WIDTH)
    net.set_train(True)

    if config.CKPT_PATH != '':
        param_dict = load_checkpoint(config.CKPT_PATH)
        load_param_into_net(net, param_dict)
        print('parameters loaded!')
    else:
        print('train from scratch...')

    criterion = ctc_loss()
    opt = mindspore.nn.RMSProp(params=net.trainable_params(), centered=True, learning_rate=config.LR_PARA,
                               momentum=config.MOMENTUM, loss_scale=config.LOSS_SCALE)

    net = WithLossCell(net, criterion)
    loss_scale_manager = mindspore.train.loss_scale_manager.FixedLossScaleManager(config.LOSS_SCALE, False)
    model = Model(net, optimizer=opt, loss_scale_manager=loss_scale_manager, amp_level="O2")

    callback = LossCallBack()
    config_ck = CheckpointConfig(save_checkpoint_steps=config.SAVE_CKPT_PER_N_STEP,
                                 keep_checkpoint_max=config.KEEP_CKPT_MAX_NUM)
    ckpoint_cb = ModelCheckpoint(prefix="CNNCTC", config=config_ck, directory=config.SAVE_PATH)

    if args_opt.device_id == 0:
        model.train(config.TRAIN_EPOCHS, ds, callbacks=[callback, ckpoint_cb], dataset_sink_mode=False)
    else:
        model.train(config.TRAIN_EPOCHS, ds, callbacks=[callback], dataset_sink_mode=False)

if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='CNNCTC arg')
    parser.add_argument('--device_id', type=int, default=0, help="Device id, default is 0.")
    parser.add_argument("--ckpt_path", type=str, default="", help="Pretrain file path.")
    parser.add_argument("--run_distribute", type=ast.literal_eval, default=False,
                        help="Run distribute, default is false.")
    args_cfg = parser.parse_args()

    cfg = Config_CNNCTC()
    if args_cfg.ckpt_path != "":
        cfg.CKPT_PATH = args_cfg.ckpt_path
    train(args_cfg, cfg)
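The `--run_distribute` flag uses `type=ast.literal_eval` rather than `type=bool`: argparse passes the raw command-line string to the type callable, and `bool("False")` is truthy, whereas `ast.literal_eval` turns `"True"`/`"False"` into real booleans. A small standalone sketch of the behavior:

```python
import argparse
import ast

parser = argparse.ArgumentParser()
# ast.literal_eval parses the literal "True"/"False" into a real bool;
# plain type=bool would treat any non-empty string (even "False") as True.
parser.add_argument("--run_distribute", type=ast.literal_eval, default=False)

args = parser.parse_args(["--run_distribute", "True"])
# args.run_distribute is the bool True, not the string "True"
```

The trade-off is that `ast.literal_eval` raises on anything that is not a valid Python literal, so invalid inputs fail loudly at parse time.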
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-#" :===========================================================================
+# ===========================================================================
 """
 network config setting, will be used in train.py and eval.py
 """