diff --git a/model_zoo/official/cv/cnnctc/README.md b/model_zoo/official/cv/cnnctc/README.md
new file mode 100644
index 00000000000..452c5a1d51c
--- /dev/null
+++ b/model_zoo/official/cv/cnnctc/README.md
@@ -0,0 +1,354 @@
+# Contents
+
+- [CNNCTC Description](#CNNCTC-description)
+- [Model Architecture](#model-architecture)
+- [Dataset](#dataset)
+- [Features](#features)
+    - [Mixed Precision](#mixed-precision)
+- [Environment Requirements](#environment-requirements)
+- [Quick Start](#quick-start)
+- [Script Description](#script-description)
+    - [Script and Sample Code](#script-and-sample-code)
+    - [Script Parameters](#script-parameters)
+    - [Training Process](#training-process)
+        - [Training](#training)
+        - [Distributed Training](#distributed-training)
+    - [Evaluation Process](#evaluation-process)
+        - [Evaluation](#evaluation)
+- [Model Description](#model-description)
+    - [Performance](#performance)
+        - [Training Performance](#training-performance)
+        - [Evaluation Performance](#evaluation-performance)
+    - [How to use](#how-to-use)
+        - [Inference](#inference)
+        - [Continue Training on the Pretrained Model](#continue-training-on-the-pretrained-model)
+- [ModelZoo Homepage](#modelzoo-homepage)
+
+# [CNNCTC Description](#contents)
+
+This paper makes three major contributions to address scene text recognition (STR).
+First, we examine the inconsistencies between training and evaluation datasets, and the performance gap that results from them.
+Second, we introduce a unified four-stage STR framework that most existing STR models fit into.
+Using this framework allows for the extensive evaluation of previously proposed STR modules and the discovery of previously
+unexplored module combinations. Third, we analyze the module-wise contributions to performance in terms of accuracy, speed,
+and memory demand, under one consistent set of training and evaluation datasets. Such analyses remove the obstacles that have
+hindered fair comparisons and clarify the performance gains of the existing modules.
+
+[Paper](https://arxiv.org/abs/1904.01906): J. Baek, G. Kim, J. Lee, S. Park, D. Han, S. Yun, S. J. Oh, and H. Lee, “What is wrong with scene text recognition model comparisons? dataset and model analysis,” ArXiv, vol. abs/1904.01906, 2019.
+
+# [Model Architecture](#contents)
+
+This is an example of training a CNN+CTC model for text recognition on the MJSynth and SynthText datasets with MindSpore.
+
+# [Dataset](#contents)
+
+The [MJSynth](https://www.robots.ox.ac.uk/~vgg/data/text/) and [SynthText](https://github.com/ankush-me/SynthText) datasets are used for model training. The [IIIT 5K-word dataset](https://cvit.iiit.ac.in/research/projects/cvit-projects/the-iiit-5k-word-dataset) is used for evaluation.
+
+- step 1:
+All the datasets have been preprocessed and stored in .lmdb format and can be downloaded [**HERE**](https://drive.google.com/drive/folders/192UfE9agQUMNq6AgU3_E05_FcPZK4hyt).
+
+- step 2:
+Uncompress the downloaded file, then rename the MJSynth dataset to MJ, the SynthText dataset to ST, and the IIIT dataset to IIIT.
+
+- step 3:
+Move the three datasets mentioned above into the `cnnctc_data` folder; the structure should be as follows:
+```
+|--- CNNCTC/
+    |--- cnnctc_data/
+        |--- ST/
+            data.mdb
+            lock.mdb
+        |--- MJ/
+            data.mdb
+            lock.mdb
+        |--- IIIT/
+            data.mdb
+            lock.mdb
+
+        ......
+```
+
+- step 4:
+Preprocess the dataset by running:
+```
+python src/preprocess_dataset.py
+```
+
+This takes around 75 minutes.
+
+# [Features](#contents)
+
+## Mixed Precision
+
+The [mixed precision](https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/mixed_precision.html) training method accelerates the deep learning neural network training process by using both single-precision and half-precision data formats, while maintaining the network accuracy achieved by single-precision training. Mixed precision training accelerates computation, reduces memory usage, and enables a larger model or batch size to be trained on specific hardware.
+For FP16 operators, if the input data type is FP32, the MindSpore backend will automatically handle it with reduced precision. Users can check the reduced-precision operators by enabling the INFO log and then searching for ‘reduce precision’.
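+
+As a reference, this repository enables mixed precision through the `amp_level` argument of `Model` together with a fixed loss scale. A condensed sketch of that setup (see `train.py` for the full context; `net` and `opt` stand for the loss-wrapped network and the RMSProp optimizer built there):
+
+```
+from mindspore.train.model import Model
+from mindspore.train.loss_scale_manager import FixedLossScaleManager
+
+# "O2" casts the network to FP16, keeps batchnorm in FP32 and applies loss scaling,
+# which prevents small FP16 gradients from underflowing to zero.
+loss_scale_manager = FixedLossScaleManager(config.LOSS_SCALE, drop_overflow_update=False)
+model = Model(net, optimizer=opt, loss_scale_manager=loss_scale_manager, amp_level="O2")
+```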
+
+# [Environment Requirements](#contents)
+
+- Hardware (Ascend)
+
+    - Prepare the hardware environment with an Ascend processor. If you want to try Ascend, please send the [application form](https://obs-9be7.obs.cn-east-2.myhuaweicloud.com/file/other/Ascend%20Model%20Zoo%E4%BD%93%E9%AA%8C%E8%B5%84%E6%BA%90%E7%94%B3%E8%AF%B7%E8%A1%A8.docx) to ascend@huawei.com. Once approved, you can get access to the resources.
+- Framework
+
+    - [MindSpore](https://www.mindspore.cn/install/en)
+- For more information, please check the resources below:
+    - [MindSpore tutorials](https://www.mindspore.cn/tutorial/zh-CN/master/index.html)
+    - [MindSpore API](https://www.mindspore.cn/api/zh-CN/master/index.html)
+
+# [Quick Start](#contents)
+
+- Install dependencies:
+```
+pip install lmdb
+pip install Pillow
+pip install tqdm
+pip install six
+```
+
+- Standalone Training:
+```
+bash scripts/run_standalone_train_ascend.sh $PRETRAINED_CKPT
+```
+
+- Distributed Training:
+```
+bash scripts/run_distribute_train_ascend.sh $RANK_TABLE_FILE $PRETRAINED_CKPT
+```
+
+- Evaluation:
+```
+bash scripts/run_eval_ascend.sh $TRAINED_CKPT
+```
+
+# [Script Description](#contents)
+
+## [Script and Sample Code](#contents)
+
+The entire code structure is as follows:
+```
+|--- CNNCTC/
+    |---README.md    // descriptions about cnnctc
+    |---train.py    // train scripts
+    |---eval.py    // eval scripts
+    |---scripts
+        |---run_standalone_train_ascend.sh    // shell script for standalone on ascend
+        |---run_distribute_train_ascend.sh    // shell script for distributed on ascend
+        |---run_eval_ascend.sh    // shell script for eval on ascend
+    |---src
+        |---__init__.py    // init file
+        |---cnn_ctc.py    // cnn_ctc network
+        |---config.py    // total config
+        |---callback.py    // loss callback file
+        |---dataset.py    // process dataset
+        |---util.py    // routine operation
+        |---generate_hccn_file.py    // generate distribute json file
+        |---preprocess_dataset.py    // preprocess dataset
+
+```
+
+## [Script Parameters](#contents)
+
+Parameters for both training and evaluation can be set in `config.py`.
+
+Arguments:
+  * `--CHARACTER`: Character labels.
+  * `--NUM_CLASS`: The number of classes, including all character labels and the '[blank]' label for CTCLoss.
+  * `--HIDDEN_SIZE`: Model hidden size.
+  * `--FINAL_FEATURE_WIDTH`: The number of features.
+  * `--IMG_H`: The height of the input image.
+  * `--IMG_W`: The width of the input image.
+  * `--TRAIN_DATASET_PATH`: The path to the training dataset.
+  * `--TRAIN_DATASET_INDEX_PATH`: The path to the training dataset index file, which determines the data order.
+  * `--TRAIN_BATCH_SIZE`: Training batch size. The batch size and index file must together guarantee input data of a fixed shape.
+  * `--TRAIN_DATASET_SIZE`: Training dataset size.
+  * `--TEST_DATASET_PATH`: The path to the test dataset.
+  * `--TEST_BATCH_SIZE`: Test batch size.
+  * `--TEST_DATASET_SIZE`: Test dataset size.
+  * `--TRAIN_EPOCHS`: Total training epochs.
+  * `--CKPT_PATH`: The path to the model checkpoint file; can be used to resume training and for evaluation.
+  * `--SAVE_PATH`: The path to save model checkpoint files.
+  * `--LR`: Learning rate for standalone training.
+  * `--LR_PARA`: Learning rate for distributed training.
+  * `--MOMENTUM`: Momentum.
+  * `--LOSS_SCALE`: Loss scale to prevent gradient underflow.
+  * `--SAVE_CKPT_PER_N_STEP`: Save a model checkpoint file every N steps.
+  * `--KEEP_CKPT_MAX_NUM`: The maximum number of saved model checkpoint files.
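+
+Both `train.py` and `eval.py` consume these values through the `Config_CNNCTC` class, and the command line only overrides the checkpoint path; a condensed sketch of that pattern (argument names follow the two scripts):
+
+```
+from src.config import Config_CNNCTC
+
+cfg = Config_CNNCTC()
+if args_opt.ckpt_path != "":            # --ckpt_path from argparse
+    cfg.CKPT_PATH = args_opt.ckpt_path  # resume training from, or evaluate, this checkpoint
+```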
+
+## [Training Process](#contents)
+
+### Training
+
+- Standalone Training:
+```
+bash scripts/run_standalone_train_ascend.sh $PRETRAINED_CKPT
+```
+
+Results and checkpoints are written to the `./train` folder. The log can be found in `./train/log` and loss values are recorded in `./train/loss.log`.
+
+`$PRETRAINED_CKPT` is the path to a model checkpoint and is **optional**. If none is given, the model will be trained from scratch.
+
+- Distributed Training:
+```
+bash scripts/run_distribute_train_ascend.sh $RANK_TABLE_FILE $PRETRAINED_CKPT
+```
+
+Results and checkpoints are written to the `./train_parallel_{i}` folder for device `i` respectively.
+The log can be found in `./train_parallel_{i}/log_{i}.log` and loss values are recorded in `./train_parallel_{i}/loss.log`.
+
+`$RANK_TABLE_FILE` is needed when you are running a distributed task on Ascend.
+`$PRETRAINED_CKPT` is the path to a model checkpoint and is **optional**. If none is given, the model will be trained from scratch.
+
+### Training Result
+
+Training results are stored in the example path, in folders whose names begin with "train" or "train_parallel". You can find checkpoint files together with results like the following in loss.log.
+
+```
+# distribute training result(8p)
+epoch: 1 step: 1 , loss is 76.25, average time per step is 0.335177839748392712
+epoch: 1 step: 2 , loss is 73.46875, average time per step is 0.36798572540283203
+epoch: 1 step: 3 , loss is 69.46875, average time per step is 0.3429678678512573
+epoch: 1 step: 4 , loss is 64.3125, average time per step is 0.33512671788533527
+epoch: 1 step: 5 , loss is 58.375, average time per step is 0.33149147033691406
+epoch: 1 step: 6 , loss is 52.7265625, average time per step is 0.3292975425720215
+...
+epoch: 1 step: 8689 , loss is 9.706798802612482, average time per step is 0.3184656601312549
+epoch: 1 step: 8690 , loss is 9.70612545289855, average time per step is 0.3184725407765116
+epoch: 1 step: 8691 , loss is 9.70695776049204, average time per step is 0.31847309686135555
+epoch: 1 step: 8692 , loss is 9.707279624277456, average time per step is 0.31847339290613375
+epoch: 1 step: 8693 , loss is 9.70763437950938, average time per step is 0.3184720295013031
+epoch: 1 step: 8694 , loss is 9.707695425072046, average time per step is 0.31847410284595573
+epoch: 1 step: 8695 , loss is 9.708408273381295, average time per step is 0.31847338271072345
+epoch: 1 step: 8696 , loss is 9.708703753591953, average time per step is 0.3184726025560777
+epoch: 1 step: 8697 , loss is 9.709536406025824, average time per step is 0.31847212061114694
+epoch: 1 step: 8698 , loss is 9.708542263610315, average time per step is 0.3184715309307257
+```
+
+## [Evaluation Process](#contents)
+
+### Evaluation
+
+- Evaluation:
+```
+bash scripts/run_eval_ascend.sh $TRAINED_CKPT
+```
+
+The model will be evaluated on the IIIT dataset; sample results and the overall accuracy will be printed.
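+
+Under the hood, `eval.py` decodes the network output greedily and compares the decoded strings with the ground truth. A condensed sketch of the decision rule (`model_predict` and `converter` follow the names in `eval.py` and `src/util.py`; `batch_size` and `labels` stand for `config.TEST_BATCH_SIZE` and the decoded ground-truth strings):
+
+```
+import numpy as np
+
+# model_predict: [batch, width, num_class] logits from CNNCTC_Model
+preds_index = np.argmax(model_predict, 2).reshape([-1])       # greedy choice per time step
+preds_size = np.array([model_predict.shape[1]] * batch_size)  # every sample spans the full width
+preds_str = converter.decode(preds_index, preds_size)         # collapse repeats, drop '[blank]'
+accuracy = sum(p == l for p, l in zip(preds_str, labels)) / len(labels)
+```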
+
+# [Model Description](#contents)
+
+## [Performance](#contents)
+
+### Training Performance
+
+| Parameters                 | CNNCTC                                                       |
+| -------------------------- | ------------------------------------------------------------ |
+| Model Version              | V1                                                           |
+| Resource                   | Ascend 910; CPU 2.60GHz, 192 cores; Memory 755G              |
+| Uploaded Date              | 09/28/2020 (month/day/year)                                  |
+| MindSpore Version          | 1.0.0                                                        |
+| Dataset                    | MJSynth, SynthText                                           |
+| Training Parameters        | epoch=3, batch_size=192                                      |
+| Optimizer                  | RMSProp                                                      |
+| Loss Function              | CTCLoss                                                      |
+| Speed                      | 1pc: 300 ms/step; 8pcs: 310 ms/step                          |
+| Total time                 | 1pc: 18 hours; 8pcs: 2.3 hours                               |
+| Parameters (M)             | 177                                                          |
+| Scripts                    | https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/cnnctc |
+
+### Evaluation Performance
+
+| Parameters          | CNNCTC                      |
+| ------------------- | --------------------------- |
+| Model Version       | V1                          |
+| Resource            | Ascend 910                  |
+| Uploaded Date       | 09/28/2020 (month/day/year) |
+| MindSpore Version   | 1.0.0                       |
+| Dataset             | IIIT5K                      |
+| batch_size          | 192                         |
+| outputs             | Accuracy                    |
+| Accuracy            | 85%                         |
+| Model for inference | 675M (.ckpt file)           |
+
+## [How to use](#contents)
+
+### Inference
+
+If you need to use the trained model to perform inference on multiple hardware platforms, such as GPU, Ascend 910 or Ascend 310, you can refer to this [Link](https://www.mindspore.cn/tutorial/zh-CN/master/advanced_use/network_migration.html). A simple example of the steps follows:
+
+- Running on Ascend
+
+  ```
+  # Set context
+  context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
+  context.set_context(device_id=cfg.device_id)
+
+  # Load unseen dataset for inference
+  dataset = dataset.create_dataset(cfg.data_path, 1, False)
+
+  # Define model
+  net = CNNCTC(cfg.NUM_CLASS, cfg.HIDDEN_SIZE, cfg.FINAL_FEATURE_WIDTH)
+  opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01,
+                 cfg.momentum, weight_decay=cfg.weight_decay)
+  loss = P.CTCLoss(preprocess_collapse_repeated=False,
+                   ctc_merge_repeated=True,
+                   ignore_longer_outputs_than_inputs=False)
+  model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
+
+  # Load pre-trained model
+  param_dict = load_checkpoint(cfg.checkpoint_path)
+  load_param_into_net(net, param_dict)
+  net.set_train(False)
+
+  # Make predictions on the unseen dataset
+  acc = model.eval(dataset)
+  print("accuracy: ", acc)
+  ```
+
+### Continue Training on the Pretrained Model
+
+- running on Ascend
+
+  ```
+  # Load dataset
+  dataset = create_dataset(cfg.data_path, 1)
+  batch_num = dataset.get_dataset_size()
+
+  # Define model
+  net = CNNCTC(cfg.NUM_CLASS, cfg.HIDDEN_SIZE, cfg.FINAL_FEATURE_WIDTH)
+  # Continue training if set pre_trained to be True
+  if cfg.pre_trained:
+      param_dict = load_checkpoint(cfg.checkpoint_path)
+      load_param_into_net(net, param_dict)
+  lr = lr_steps(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size,
+                steps_per_epoch=batch_num)
+  opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()),
+                 Tensor(lr), cfg.momentum, weight_decay=cfg.weight_decay)
+  loss = P.CTCLoss(preprocess_collapse_repeated=False,
+                   ctc_merge_repeated=True,
+                   ignore_longer_outputs_than_inputs=False)
+  model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},
+                amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None)
+
+  # Set callbacks
+  config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5,
+                               keep_checkpoint_max=cfg.keep_checkpoint_max)
+  time_cb = TimeMonitor(data_size=batch_num)
+  ckpoint_cb = ModelCheckpoint(prefix="train_cnnctc", directory="./",
+                               config=config_ck)
+  loss_cb = LossMonitor()
+
+  # Start training
+  model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb])
+  print("train success")
+  ```
+
+# [ModelZoo Homepage](#contents)
+
+Please check the official [homepage](https://gitee.com/mindspore/mindspore/tree/master/model_zoo).
diff --git a/model_zoo/official/cv/cnnctc/eval.py b/model_zoo/official/cv/cnnctc/eval.py
new file mode 100644
index 00000000000..e4c421bf4e8
--- /dev/null
+++ b/model_zoo/official/cv/cnnctc/eval.py
@@ -0,0 +1,109 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""cnnctc eval"""
+
+import argparse
+import time
+import numpy as np
+
+from mindspore import Tensor, context
+import mindspore.common.dtype as mstype
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from mindspore.dataset import GeneratorDataset
+
+from src.util import CTCLabelConverter, AverageMeter
+from src.config import Config_CNNCTC
+from src.dataset import IIIT_Generator_batch
+from src.cnn_ctc import CNNCTC_Model
+
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False,
+                    save_graphs_path=".", enable_auto_mixed_precision=False)
+
+def test_dataset_creator():
+    ds = GeneratorDataset(IIIT_Generator_batch, ['img', 'label_indices', 'text', 'sequence_length', 'label_str'])
+    return ds
+
+
+def test(config):
+    ds = test_dataset_creator()
+
+    net = CNNCTC_Model(config.NUM_CLASS, config.HIDDEN_SIZE, config.FINAL_FEATURE_WIDTH)
+
+    ckpt_path = config.CKPT_PATH
+    param_dict = load_checkpoint(ckpt_path)
+    load_param_into_net(net, param_dict)
+    print('parameters loaded from: ', ckpt_path)
+
+    converter = CTCLabelConverter(config.CHARACTER)
+
+    model_run_time = AverageMeter()
+    npu_to_cpu_time = AverageMeter()
+    postprocess_time = AverageMeter()
+
+    count = 0
+    correct_count = 0
+    for data in ds.create_tuple_iterator():
+        img, _, text, _, length = data
+
+        img_tensor = Tensor(img, mstype.float32)
+
+        model_run_begin = time.time()
+        model_predict = net(img_tensor)
+        model_run_end = time.time()
+        model_run_time.update(model_run_end - model_run_begin)
+
+        npu_to_cpu_begin = time.time()
+        model_predict = np.squeeze(model_predict.asnumpy())
+        npu_to_cpu_end = time.time()
+        npu_to_cpu_time.update(npu_to_cpu_end - npu_to_cpu_begin)
+
+        postprocess_begin = time.time()
+        preds_size = np.array([model_predict.shape[1]] * config.TEST_BATCH_SIZE)
+        preds_index = np.argmax(model_predict, 2)
+        preds_index = np.reshape(preds_index, [-1])
+        preds_str = converter.decode(preds_index, preds_size)
+        postprocess_end = time.time()
+        postprocess_time.update(postprocess_end - postprocess_begin)
+
+        label_str = converter.reverse_encode(text.asnumpy(), length.asnumpy())
+
+        # the first batch includes graph compilation, so its timings are discarded
+        if count == 0:
+            model_run_time.reset()
+            npu_to_cpu_time.reset()
+            postprocess_time.reset()
+        else:
+            print('---------model run time--------', model_run_time.avg)
+            print('---------npu_to_cpu run time--------', npu_to_cpu_time.avg)
+            print('---------postprocess run time--------', postprocess_time.avg)
+
+        print("Prediction samples: \n", preds_str[:5])
+        print("Ground truth: \n", label_str[:5])
+        for pred, label in zip(preds_str, label_str):
+            if pred == label:
+                correct_count += 1
+            count += 1
+
+    print('accuracy: ', correct_count / count)
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description="CNNCTC evaluation")
+    parser.add_argument('--device_id', type=int, default=0, help="Device id, default is 0.")
+    parser.add_argument("--ckpt_path", type=str, default="", help="Path to the trained checkpoint file.")
+    args_opt = parser.parse_args()
+
+    cfg = Config_CNNCTC()
+    if args_opt.ckpt_path != "":
+        cfg.CKPT_PATH = args_opt.ckpt_path
+    test(cfg)
diff --git a/model_zoo/official/cv/cnnctc/scripts/run_distribute_train_ascend.sh b/model_zoo/official/cv/cnnctc/scripts/run_distribute_train_ascend.sh
new file mode 100644
index 00000000000..4d9b072be47
--- /dev/null
+++ b/model_zoo/official/cv/cnnctc/scripts/run_distribute_train_ascend.sh
@@ -0,0 +1,57 @@
+#!/bin/bash
+# Copyright 2020 Huawei Technologies Co.,
Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +current_exec_path=$(pwd) +echo ${current_exec_path} + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +PATH1=$(get_real_path $1) +echo $PATH1 + +PATH2=$(get_real_path $2) +echo $PATH2 + +python ${current_exec_path}/src/generate_hccn_file.py --rank_file=$PATH1 +export RANK_TABLE_FILE=$PATH1 +export RANK_SIZE=8 +ulimit -u unlimited +for((i=0;i<$RANK_SIZE;i++)); +do + rm ./train_parallel_$i/ -rf + mkdir ./train_parallel_$i + cp ./*.py ./train_parallel_$i + cp ./scripts/*.sh ./train_parallel_$i + cp -r ./src ./train_parallel_$i + cd ./train_parallel_$i || exit + export RANK_ID=$i + export DEVICE_ID=$i + echo "start training for rank $RANK_ID, device $DEVICE_ID" + if [ -f $PATH2 ] + then + python train.py --device_id=$i --ckpt_path=$PATH2 --run_distribute=True >log_$i.log 2>&1 & + else + python train.py --device_id=$i --run_distribute=True >log_$i.log 2>&1 & + fi + cd .. || exit +done + diff --git a/model_zoo/official/cv/cnnctc/scripts/run_eval_ascend.sh b/model_zoo/official/cv/cnnctc/scripts/run_eval_ascend.sh new file mode 100644 index 00000000000..1b93b0c4ff4 --- /dev/null +++ b/model_zoo/official/cv/cnnctc/scripts/run_eval_ascend.sh @@ -0,0 +1,54 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# -ne 1 ] +then + echo "Usage: sh run_eval_ascend.sh [TRAINED_CKPT]" +exit 1 +fi + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +PATH1=$(get_real_path $1) +echo $PATH1 +if [ ! -f $PATH1 ] +then + echo "error: TRAINED_CKPT=$PATH1 is not a file" +exit 1 +fi + +ulimit -u unlimited +export DEVICE_ID=0 + +if [ -d "eval" ]; +then + rm -rf ./eval +fi +mkdir ./eval +cp ./*.py ./eval +cp ./scripts/*.sh ./eval +cp -r ./src ./eval +cd ./eval || exit +echo "start infering for device $DEVICE_ID" +env > env.log +python eval.py --device_id=$DEVICE_ID --ckpt_path=$PATH1 &> log & +cd .. 
|| exit diff --git a/model_zoo/official/cv/cnnctc/scripts/run_standalone_train_ascend.sh b/model_zoo/official/cv/cnnctc/scripts/run_standalone_train_ascend.sh new file mode 100644 index 00000000000..ffeeb5a2ac3 --- /dev/null +++ b/model_zoo/official/cv/cnnctc/scripts/run_standalone_train_ascend.sh @@ -0,0 +1,45 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +get_real_path(){ + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} +PATH1=$(get_real_path $1) + +ulimit -u unlimited + +if [ -d "train" ]; +then + rm -rf ./train +fi +mkdir ./train +cp ./*.py ./train +cp ./scripts/*.sh ./train +cp -r ./src ./train +cd ./train || exit +echo "start training for device $DEVICE_ID" +env > env.log +if [ -f $PATH1 ] +then + python train.py --device_id=$DEVICE_ID --ckpt_path=$PATH1 --run_distribute=False &> log & +else + python train.py --device_id=$DEVICE_ID --run_distribute=False &> log & +fi +cd .. || exit diff --git a/model_zoo/official/cv/cnnctc/src/__init__.py b/model_zoo/official/cv/cnnctc/src/__init__.py new file mode 100644 index 00000000000..8d62ac3491e --- /dev/null +++ b/model_zoo/official/cv/cnnctc/src/__init__.py @@ -0,0 +1,15 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""src init file""" diff --git a/model_zoo/official/cv/cnnctc/src/callback.py b/model_zoo/official/cv/cnnctc/src/callback.py new file mode 100644 index 00000000000..6ebcfd784c3 --- /dev/null +++ b/model_zoo/official/cv/cnnctc/src/callback.py @@ -0,0 +1,71 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""loss callback""" + +import time +from mindspore.train.callback import Callback +from .util import AverageMeter + +class LossCallBack(Callback): + """ + Monitor the loss in training. + + If the loss is NAN or INF terminating training. + + Note: + If per_print_times is 0 do not print loss. + + Args: + per_print_times (int): Print loss every times. Default: 1. + """ + + def __init__(self, per_print_times=1): + super(LossCallBack, self).__init__() + if not isinstance(per_print_times, int) or per_print_times < 0: + raise ValueError("print_step must be int and >= 0.") + self._per_print_times = per_print_times + self.loss_avg = AverageMeter() + self.timer = AverageMeter() + self.start_time = time.time() + + def step_end(self, run_context): + cb_params = run_context.original_args() + + loss = cb_params.net_outputs.asnumpy() + + cur_step_in_epoch = (cb_params.cur_step_num - 1) % cb_params.batch_num + 1 + cur_num = cb_params.cur_step_num + + if cur_step_in_epoch % 2000 == 1: + self.loss_avg = AverageMeter() + self.timer = AverageMeter() + self.start_time = time.time() + else: + self.timer.update(time.time() - self.start_time) + self.start_time = time.time() + + self.loss_avg.update(loss) + + if self._per_print_times != 0 and cur_num % self._per_print_times == 0: + loss_file = open("./loss.log", "a+") + loss_file.write("epoch: %s step: %s , loss is %s, average time per step is %s" % ( + cb_params.cur_epoch_num, cur_step_in_epoch, + self.loss_avg.avg, self.timer.avg)) + loss_file.write("\n") + loss_file.close() + + print("epoch: %s step: %s , loss is %s, average time per step is %s" % ( + cb_params.cur_epoch_num, cur_step_in_epoch, + self.loss_avg.avg, self.timer.avg)) diff --git a/model_zoo/official/cv/cnnctc/src/cnn_ctc.py b/model_zoo/official/cv/cnnctc/src/cnn_ctc.py new file mode 100644 index 00000000000..89abbf90587 --- /dev/null +++ b/model_zoo/official/cv/cnnctc/src/cnn_ctc.py @@ -0,0 +1,255 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""cnn_ctc network define""" + +import mindspore.nn as nn +from mindspore.ops import operations as P +from mindspore.common.initializer import TruncatedNormal, initializer +import mindspore.common.dtype as mstype + +class CNNCTC_Model(nn.Cell): + + def __init__(self, num_class, hidden_size, final_feature_width): + super(CNNCTC_Model, self).__init__() + + self.num_class = num_class + self.hidden_size = hidden_size + self.final_feature_width = final_feature_width + + self.FeatureExtraction = ResNet_FeatureExtractor() + self.Prediction = nn.Dense(self.hidden_size, self.num_class) + + self.transpose = P.Transpose() + self.reshape = P.Reshape() + + def construct(self, x): + x = self.FeatureExtraction(x) + x = self.transpose(x, (0, 3, 1, 2)) # [b, c, h, w] -> [b, w, c, h] + + x = self.reshape(x, (-1, self.hidden_size)) + x = self.Prediction(x) + x = self.reshape(x, (-1, self.final_feature_width, self.num_class)) + + return x + + +class WithLossCell(nn.Cell): + + def __init__(self, backbone, loss_fn): + super(WithLossCell, self).__init__(auto_prefix=False) + self._backbone = backbone + self._loss_fn = loss_fn + + def construct(self, img, label_indices, text, sequence_length): + model_predict = self._backbone(img) + return self._loss_fn(model_predict, label_indices, text, sequence_length) + + @property + def backbone_network(self): + return self._backbone + +class ctc_loss(nn.Cell): + + def __init__(self): + super(ctc_loss, self).__init__() + + self.loss = P.CTCLoss(preprocess_collapse_repeated=False, + ctc_merge_repeated=True, + ignore_longer_outputs_than_inputs=False) + + self.mean = P.ReduceMean() + self.transpose = P.Transpose() + self.reshape = P.Reshape() + + def construct(self, inputs, labels_indices, labels_values, sequence_length): + inputs = self.transpose(inputs, (1, 0, 2)) + + loss, _ = self.loss(inputs, labels_indices, labels_values, sequence_length) + + loss = self.mean(loss) + return loss + + +class ResNet_FeatureExtractor(nn.Cell): + def __init__(self): + super(ResNet_FeatureExtractor, self).__init__() + self.ConvNet = ResNet(3, 512, BasicBlock, [1, 2, 5, 3]) + + def construct(self, featuremap): + return self.ConvNet(featuremap) + + +class ResNet(nn.Cell): + def __init__(self, input_channel, output_channel, block, layers): + super(ResNet, self).__init__() + + self.output_channel_block = [int(output_channel / 4), int(output_channel / 2), output_channel, output_channel] + + self.inplanes = int(output_channel / 8) + self.conv0_1 = ms_conv3x3(input_channel, int(output_channel / 16), stride=1, padding=1, pad_mode='pad') + self.bn0_1 = ms_fused_bn(int(output_channel / 16)) + self.conv0_2 = ms_conv3x3(int(output_channel / 16), self.inplanes, stride=1, padding=1, pad_mode='pad') + self.bn0_2 = ms_fused_bn(self.inplanes) + self.relu = P.ReLU() + + self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='valid') + self.layer1 = self._make_layer(block, self.output_channel_block[0], layers[0]) + self.conv1 = ms_conv3x3(self.output_channel_block[0], self.output_channel_block[0], stride=1, padding=1, + pad_mode='pad') + self.bn1 = ms_fused_bn(self.output_channel_block[0]) + + self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2, pad_mode='valid') + self.layer2 = self._make_layer(block, self.output_channel_block[1], layers[1]) + self.conv2 = ms_conv3x3(self.output_channel_block[1], self.output_channel_block[1], stride=1, padding=1, + pad_mode='pad') + self.bn2 = ms_fused_bn(self.output_channel_block[1]) + + 
self.pad = P.Pad(((0, 0), (0, 0), (0, 0), (1, 1))) + self.maxpool3 = nn.MaxPool2d(kernel_size=2, stride=(2, 1), pad_mode='valid') + self.layer3 = self._make_layer(block, self.output_channel_block[2], layers[2]) + self.conv3 = ms_conv3x3(self.output_channel_block[2], self.output_channel_block[2], stride=1, padding=1, + pad_mode='pad') + self.bn3 = ms_fused_bn(self.output_channel_block[2]) + + self.layer4 = self._make_layer(block, self.output_channel_block[3], layers[3]) + self.conv4_1 = ms_conv2x2(self.output_channel_block[3], self.output_channel_block[3], stride=(2, 1), + pad_mode='valid') + self.bn4_1 = ms_fused_bn(self.output_channel_block[3]) + + self.conv4_2 = ms_conv2x2(self.output_channel_block[3], self.output_channel_block[3], stride=1, padding=0, + pad_mode='valid') + self.bn4_2 = ms_fused_bn(self.output_channel_block[3]) + + def _make_layer(self, block, planes, blocks, stride=1): + downsample = None + if stride != 1 or self.inplanes != planes * block.expansion: + downsample = nn.SequentialCell( + [ms_conv1x1(self.inplanes, planes * block.expansion, stride=stride), + ms_fused_bn(planes * block.expansion)] + ) + + layers = [] + layers.append(block(self.inplanes, planes, stride, downsample)) + self.inplanes = planes * block.expansion + for _ in range(1, blocks): + layers.append(block(self.inplanes, planes)) + + return nn.SequentialCell(layers) + + def construct(self, x): + x = self.conv0_1(x) + x = self.bn0_1(x) + x = self.relu(x) + x = self.conv0_2(x) + x = self.bn0_2(x) + x = self.relu(x) + + x = self.maxpool1(x) + x = self.layer1(x) + x = self.conv1(x) + x = self.bn1(x) + x = self.relu(x) + + x = self.maxpool2(x) + x = self.layer2(x) + x = self.conv2(x) + x = self.bn2(x) + x = self.relu(x) + + x = self.pad(x) + x = self.maxpool3(x) + x = self.layer3(x) + x = self.conv3(x) + x = self.bn3(x) + x = self.relu(x) + + x = self.layer4(x) + x = self.pad(x) + x = self.conv4_1(x) + x = self.bn4_1(x) + x = self.relu(x) + x = self.conv4_2(x) + x = self.bn4_2(x) + x = self.relu(x) + + return x + + +class BasicBlock(nn.Cell): + expansion = 1 + + def __init__(self, inplanes, planes, stride=1, downsample=None): + super(BasicBlock, self).__init__() + + self.conv1 = ms_conv3x3(inplanes, planes, stride=stride, padding=1, pad_mode='pad') + self.bn1 = ms_fused_bn(planes) + self.conv2 = ms_conv3x3(planes, planes, stride=stride, padding=1, pad_mode='pad') + self.bn2 = ms_fused_bn(planes) + self.relu = P.ReLU() + self.downsample = downsample + self.add = P.TensorAdd() + + def construct(self, x): + residual = x + + out = self.conv1(x) + out = self.bn1(out) + out = self.relu(out) + + out = self.conv2(out) + out = self.bn2(out) + + if self.downsample is not None: + residual = self.downsample(x) + out = self.add(out, residual) + out = self.relu(out) + + return out + + +def weight_variable(shape, factor=0.1, half_precision=False): + if half_precision: + return initializer(TruncatedNormal(0.02), shape, dtype=mstype.float16) + + return TruncatedNormal(0.02) + + +def ms_conv3x3(in_channels, out_channels, stride=1, padding=0, pad_mode='same', has_bias=False): + """Get a conv2d layer with 3x3 kernel size.""" + init_value = weight_variable((out_channels, in_channels, 3, 3)) + return nn.Conv2d(in_channels, out_channels, + kernel_size=3, stride=stride, padding=padding, pad_mode=pad_mode, weight_init=init_value, + has_bias=has_bias) + + +def ms_conv1x1(in_channels, out_channels, stride=1, padding=0, pad_mode='same', has_bias=False): + """Get a conv2d layer with 1x1 kernel size.""" + init_value = 
weight_variable((out_channels, in_channels, 1, 1)) + return nn.Conv2d(in_channels, out_channels, + kernel_size=1, stride=stride, padding=padding, pad_mode=pad_mode, weight_init=init_value, + has_bias=has_bias) + + +def ms_conv2x2(in_channels, out_channels, stride=1, padding=0, pad_mode='same', has_bias=False): + """Get a conv2d layer with 2x2 kernel size.""" + init_value = weight_variable((out_channels, in_channels, 1, 1)) + return nn.Conv2d(in_channels, out_channels, + kernel_size=2, stride=stride, padding=padding, pad_mode=pad_mode, weight_init=init_value, + has_bias=has_bias) + + +def ms_fused_bn(channels, momentum=0.1): + """Get a fused batchnorm""" + return nn.BatchNorm2d(channels, momentum=momentum) diff --git a/model_zoo/official/cv/cnnctc/src/config.py b/model_zoo/official/cv/cnnctc/src/config.py new file mode 100644 index 00000000000..1b8c7755b8f --- /dev/null +++ b/model_zoo/official/cv/cnnctc/src/config.py @@ -0,0 +1,43 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""network config setting, will be used in train.py and eval.py""" + +class Config_CNNCTC(): + # model config + CHARACTER = '0123456789abcdefghijklmnopqrstuvwxyz' + NUM_CLASS = len(CHARACTER) + 1 + HIDDEN_SIZE = 512 + FINAL_FEATURE_WIDTH = 26 + + # dataset config + IMG_H = 32 + IMG_W = 100 + TRAIN_DATASET_PATH = 'CNNCTC_Data/ST_MJ/' + TRAIN_DATASET_INDEX_PATH = 'CNNCTC_Data/st_mj_fixed_length_index_list.pkl' + TRAIN_BATCH_SIZE = 192 + TEST_DATASET_PATH = 'CNNCTC_Data/IIIT5k_3000' + TEST_BATCH_SIZE = 256 + TEST_DATASET_SIZE = 2976 + TRAIN_EPOCHS = 3 + + # training config + CKPT_PATH = '' + SAVE_PATH = './' + LR = 1e-4 + LR_PARA = 5e-4 + MOMENTUM = 0.8 + LOSS_SCALE = 8096 + SAVE_CKPT_PER_N_STEP = 2000 + KEEP_CKPT_MAX_NUM = 5 diff --git a/model_zoo/official/cv/cnnctc/src/dataset.py b/model_zoo/official/cv/cnnctc/src/dataset.py new file mode 100644 index 00000000000..475c6c997fd --- /dev/null +++ b/model_zoo/official/cv/cnnctc/src/dataset.py @@ -0,0 +1,265 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""cnn_ctc dataset""" + +import sys +import pickle +import math +import six +import numpy as np +from PIL import Image +import lmdb + +from mindspore.communication.management import get_rank, get_group_size + +from .util import CTCLabelConverter +from .config import Config_CNNCTC + +config = Config_CNNCTC() + +class NormalizePAD(): + + def __init__(self, max_size, PAD_type='right'): + self.max_size = max_size + self.PAD_type = PAD_type + + def __call__(self, img): + # toTensor + img = np.array(img, dtype=np.float32) + img = img.transpose([2, 0, 1]) + img = img.astype(np.float) + img = np.true_divide(img, 255) + # normalize + img = np.subtract(img, 0.5) + img = np.true_divide(img, 0.5) + + _, _, w = img.shape + Pad_img = np.zeros(shape=self.max_size, dtype=np.float32) + Pad_img[:, :, :w] = img # right pad + if self.max_size[2] != w: # add border Pad + Pad_img[:, :, w:] = np.tile(np.expand_dims(img[:, :, w - 1], 2), (1, 1, self.max_size[2] - w)) + + return Pad_img + + +class AlignCollate(): + + def __init__(self, imgH=32, imgW=100): + self.imgH = imgH + self.imgW = imgW + + def __call__(self, images): + + resized_max_w = self.imgW + input_channel = 3 + transform = NormalizePAD((input_channel, self.imgH, resized_max_w)) + + resized_images = [] + for image in images: + w, h = image.size + ratio = w / float(h) + if math.ceil(self.imgH * ratio) > self.imgW: + resized_w = self.imgW + else: + resized_w = math.ceil(self.imgH * ratio) + + resized_image = image.resize((resized_w, self.imgH), Image.BICUBIC) + resized_images.append(transform(resized_image)) + + image_tensors = np.concatenate([np.expand_dims(t, 0) for t in resized_images], 0) + + return image_tensors + + +def get_img_from_lmdb(env, index): + with env.begin(write=False) as txn: + label_key = 'label-%09d'.encode() % index + label = txn.get(label_key).decode('utf-8') + img_key = 'image-%09d'.encode() % index + imgbuf = txn.get(img_key) + + buf = six.BytesIO() + buf.write(imgbuf) + buf.seek(0) + try: + img = Image.open(buf).convert('RGB') # for color image + + except IOError: + print(f'Corrupted image for {index}') + # make dummy image and dummy label for corrupted image. 
+ img = Image.new('RGB', (config.IMG_W, config.IMG_H)) + label = '[dummy_label]' + + label = label.lower() + + return img, label + + +class ST_MJ_Generator_batch_fixed_length: + def __init__(self): + self.align_collector = AlignCollate() + self.converter = CTCLabelConverter(config.CHARACTER) + self.env = lmdb.open(config.TRAIN_DATASET_PATH, max_readers=32, readonly=True, lock=False, readahead=False, + meminit=False) + if not self.env: + print('cannot create lmdb from %s' % (config.TRAIN_DATASET_PATH)) + raise ValueError(config.TRAIN_DATASET_PATH) + + with open(config.TRAIN_DATASET_INDEX_PATH, 'rb') as f: + self.st_mj_filtered_index_list = pickle.load(f) + + print(f'num of samples in ST_MJ dataset: {len(self.st_mj_filtered_index_list)}') + self.dataset_size = len(self.st_mj_filtered_index_list) // config.TRAIN_BATCH_SIZE + self.batch_size = config.TRAIN_BATCH_SIZE + + def __len__(self): + return self.dataset_size + + def __getitem__(self, item): + img_ret = [] + text_ret = [] + + for i in range(item * self.batch_size, (item + 1) * self.batch_size): + index = self.st_mj_filtered_index_list[i] + img, label = get_img_from_lmdb(self.env, index) + + img_ret.append(img) + text_ret.append(label) + + img_ret = self.align_collector(img_ret) + text_ret, length = self.converter.encode(text_ret) + + label_indices = [] + for i, _ in enumerate(length): + for j in range(length[i]): + label_indices.append((i, j)) + label_indices = np.array(label_indices, np.int64) + sequence_length = np.array([config.FINAL_FEATURE_WIDTH] * config.TRAIN_BATCH_SIZE, dtype=np.int32) + text_ret = text_ret.astype(np.int32) + + return img_ret, label_indices, text_ret, sequence_length + +class ST_MJ_Generator_batch_fixed_length_para: + def __init__(self): + self.align_collector = AlignCollate() + self.converter = CTCLabelConverter(config.CHARACTER) + self.env = lmdb.open(config.TRAIN_DATASET_PATH, max_readers=32, readonly=True, lock=False, readahead=False, + meminit=False) + if not self.env: + print('cannot create lmdb from %s' % (config.TRAIN_DATASET_PATH)) + raise ValueError(config.TRAIN_DATASET_PATH) + + with open(config.TRAIN_DATASET_INDEX_PATH, 'rb') as f: + self.st_mj_filtered_index_list = pickle.load(f) + + print(f'num of samples in ST_MJ dataset: {len(self.st_mj_filtered_index_list)}') + self.rank_id = get_rank() + self.rank_size = get_group_size() + self.dataset_size = len(self.st_mj_filtered_index_list) // config.TRAIN_BATCH_SIZE // self.rank_size + self.batch_size = config.TRAIN_BATCH_SIZE + + def __len__(self): + return self.dataset_size + + def __getitem__(self, item): + img_ret = [] + text_ret = [] + + rank_item = (item * self.rank_size) + self.rank_id + for i in range(rank_item * self.batch_size, (rank_item + 1) * self.batch_size): + index = self.st_mj_filtered_index_list[i] + img, label = get_img_from_lmdb(self.env, index) + + img_ret.append(img) + text_ret.append(label) + + img_ret = self.align_collector(img_ret) + text_ret, length = self.converter.encode(text_ret) + + label_indices = [] + for i, _ in enumerate(length): + for j in range(length[i]): + label_indices.append((i, j)) + label_indices = np.array(label_indices, np.int64) + sequence_length = np.array([config.FINAL_FEATURE_WIDTH] * config.TRAIN_BATCH_SIZE, dtype=np.int32) + text_ret = text_ret.astype(np.int32) + + return img_ret, label_indices, text_ret, sequence_length + + +def IIIT_Generator_batch(): + max_len = int((26 + 1) // 2) + + align_collector = AlignCollate() + + converter = CTCLabelConverter(config.CHARACTER) + + env = 
lmdb.open(config.TEST_DATASET_PATH, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False) + if not env: + print('cannot create lmdb from %s' % (config.TEST_DATASET_PATH)) + sys.exit(0) + + with env.begin(write=False) as txn: + nSamples = int(txn.get('num-samples'.encode())) + nSamples = nSamples + + # Filtering + filtered_index_list = [] + for index in range(nSamples): + index += 1 # lmdb starts with 1 + label_key = 'label-%09d'.encode() % index + label = txn.get(label_key).decode('utf-8') + + if len(label) > max_len: + continue + + illegal_sample = False + for char_item in label.lower(): + if char_item not in config.CHARACTER: + illegal_sample = True + break + if illegal_sample: + continue + + filtered_index_list.append(index) + + img_ret = [] + text_ret = [] + + print(f'num of samples in IIIT dataset: {len(filtered_index_list)}') + + for index in filtered_index_list: + + img, label = get_img_from_lmdb(env, index) + + img_ret.append(img) + text_ret.append(label) + + if len(img_ret) == config.TEST_BATCH_SIZE: + img_ret = align_collector(img_ret) + text_ret, length = converter.encode(text_ret) + + label_indices = [] + for i, _ in enumerate(length): + for j in range(length[i]): + label_indices.append((i, j)) + label_indices = np.array(label_indices, np.int64) + sequence_length = np.array([26] * config.TEST_BATCH_SIZE, dtype=np.int32) + text_ret = text_ret.astype(np.int32) + + yield img_ret, label_indices, text_ret, sequence_length, length + + img_ret = [] + text_ret = [] diff --git a/model_zoo/official/cv/cnnctc/src/generate_hccn_file.py b/model_zoo/official/cv/cnnctc/src/generate_hccn_file.py new file mode 100644 index 00000000000..6c0dfef14a0 --- /dev/null +++ b/model_zoo/official/cv/cnnctc/src/generate_hccn_file.py @@ -0,0 +1,88 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""generate ascend rank file""" + +import os +import socket +import argparse + +parser = argparse.ArgumentParser(description="ascend distribute rank.") +parser.add_argument("--rank_file", type=str, default="scripts/rank_table_8p.json", help="rank_tabel_file_path.") + +def main(rank_table_file): + nproc_per_node = 8 + + visible_devices = ['0', '1', '2', '3', '4', '5', '6', '7'] + + server_id = socket.gethostbyname(socket.gethostname()) + + hccn_configs = open('/etc/hccn.conf', 'r').readlines() + device_ips = {} + for hccn_item in hccn_configs: + hccn_item = hccn_item.strip() + if hccn_item.startswith('address_'): + device_id, device_ip = hccn_item.split('=') + device_id = device_id.split('_')[1] + device_ips[device_id] = device_ip + print('device_id:{}, device_ip:{}'.format(device_id, device_ip)) + + hccn_table = {} + hccn_table['board_id'] = '0x002f' # A+K + # hccn_table['board_id'] = '0x0000' # A+X + + hccn_table['chip_info'] = '910' + hccn_table['deploy_mode'] = 'lab' + hccn_table['group_count'] = '1' + hccn_table['group_list'] = [] + instance_list = [] + for instance_id in range(nproc_per_node): + instance = {} + instance['devices'] = [] + device_id = visible_devices[instance_id] + device_ip = device_ips[device_id] + instance['devices'].append({ + 'device_id': device_id, + 'device_ip': device_ip, + }) + instance['rank_id'] = str(instance_id) + instance['server_id'] = server_id + instance_list.append(instance) + hccn_table['group_list'].append({ + 'device_num': str(nproc_per_node), + 'server_num': '1', + 'group_name': '', + 'instance_count': str(nproc_per_node), + 'instance_list': instance_list, + }) + hccn_table['para_plane_nic_location'] = 'device' + hccn_table['para_plane_nic_name'] = [] + for instance_id in range(nproc_per_node): + eth_id = visible_devices[instance_id] + hccn_table['para_plane_nic_name'].append('eth{}'.format(eth_id)) + hccn_table['para_plane_nic_num'] = str(nproc_per_node) + hccn_table['status'] = 'completed' + import json + with open(rank_table_file, 'w') as table_fp: + json.dump(hccn_table, table_fp, indent=4) + +if __name__ == '__main__': + args_opt = parser.parse_args() + rank_table = args_opt.rank_file + if os.path.exists(rank_table): + print('Rank table file exists.') + else: + print('Generating rank table file.') + main(rank_table) + print('Rank table file generated') diff --git a/model_zoo/official/cv/cnnctc/src/preprocess_dataset.py b/model_zoo/official/cv/cnnctc/src/preprocess_dataset.py new file mode 100644 index 00000000000..91392dfe35d --- /dev/null +++ b/model_zoo/official/cv/cnnctc/src/preprocess_dataset.py @@ -0,0 +1,171 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================ +"""preprocess dataset""" + +import random +import pickle +import numpy as np +import lmdb +from tqdm import tqdm + +def combine_lmdbs(lmdb_paths, lmdb_save_path): + max_len = int((26 + 1) // 2) + character = '0123456789abcdefghijklmnopqrstuvwxyz' + + env_save = lmdb.open( + lmdb_save_path, + map_size=1099511627776) + + cnt = 0 + for lmdb_path in lmdb_paths: + env = lmdb.open(lmdb_path, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False) + with env.begin(write=False) as txn: + nSamples = int(txn.get('num-samples'.encode())) + nSamples = nSamples + + # Filtering + for index in tqdm(range(nSamples)): + index += 1 # lmdb starts with 1 + label_key = 'label-%09d'.encode() % index + label = txn.get(label_key).decode('utf-8') + + if len(label) > max_len: + continue + + illegal_sample = False + for char_item in label.lower(): + if char_item not in character: + illegal_sample = True + break + if illegal_sample: + continue + + img_key = 'image-%09d'.encode() % index + imgbuf = txn.get(img_key) + + with env_save.begin(write=True) as txn_save: + cnt += 1 + + label_key_save = 'label-%09d'.encode() % cnt + label_save = label.encode() + image_key_save = 'image-%09d'.encode() % cnt + image_save = imgbuf + + txn_save.put(label_key_save, label_save) + txn_save.put(image_key_save, image_save) + + nSamples = cnt + with env_save.begin(write=True) as txn_save: + txn_save.put('num-samples'.encode(), str(nSamples).encode()) + + +def analyze_lmdb_label_length(lmdb_path, batch_size=192, num_of_combinations=1000): + label_length_dict = {} + + env = lmdb.open(lmdb_path, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False) + with env.begin(write=False) as txn: + nSamples = int(txn.get('num-samples'.encode())) + nSamples = nSamples + + for index in tqdm(range(nSamples)): + index += 1 # lmdb starts with 1 + label_key = 'label-%09d'.encode() % index + label = txn.get(label_key).decode('utf-8') + + label_length = len(label) + if label_length in label_length_dict: + label_length_dict[label_length] += 1 + else: + label_length_dict[label_length] = 1 + + sorted_label_length = sorted(label_length_dict.items(), key=lambda x: x[1], reverse=True) + + label_length_sum = 0 + label_num = 0 + lengths = [] + p = [] + for l, num in sorted_label_length: + label_length_sum += l * num + label_num += num + p.append(num) + lengths.append(l) + for i, _ in enumerate(p): + p[i] /= label_num + + average_overall_length = int(label_length_sum / label_num * batch_size) + + def get_combinations_of_fix_length(fix_length, items, p, batch_size): + ret = [] + cur_sum = 0 + ret = np.random.choice(items, batch_size - 1, True, p) + cur_sum = sum(ret) + ret = list(ret) + if fix_length - cur_sum in items: + ret.append(fix_length - cur_sum) + else: + return None + return ret + + result = [] + while len(result) < num_of_combinations: + ret = get_combinations_of_fix_length(average_overall_length, lengths, p, batch_size) + if ret is not None: + result.append(ret) + return result + + +def generate_fix_shape_index_list(lmdb_path, combinations, pkl_save_path, num_of_iters=70000): + length_index_dict = {} + + env = lmdb.open(lmdb_path, max_readers=32, readonly=True, lock=False, readahead=False, meminit=False) + with env.begin(write=False) as txn: + nSamples = int(txn.get('num-samples'.encode())) + nSamples = nSamples + + for index in tqdm(range(nSamples)): + index += 1 # lmdb starts with 1 + label_key = 'label-%09d'.encode() % index + label = 
txn.get(label_key).decode('utf-8') + + label_length = len(label) + if label_length in length_index_dict: + length_index_dict[label_length].append(index) + else: + length_index_dict[label_length] = [index] + + ret = [] + for _ in range(num_of_iters): + comb = random.choice(combinations) + for l in comb: + ret.append(random.choice(length_index_dict[l])) + + with open(pkl_save_path, 'wb') as f: + pickle.dump(ret, f, -1) + + +if __name__ == '__main__': + # step 1: combine the SynthText dataset and MJSynth dataset into a single lmdb file + print('Begin to combine multiple lmdb datasets') + combine_lmdbs(['/home/workspace/mindspore_dataset/CNNCTC_Data/1_ST/', + '/home/workspace/mindspore_dataset/CNNCTC_Data/MJ_train/'], + '/home/workspace/mindspore_dataset/CNNCTC_Data/ST_MJ') + + # step 2: generate the order of input data, guarantee that the input batch shape is fixed + print('Begin to generate the index order of input data') + combination = analyze_lmdb_label_length('/home/workspace/mindspore_dataset/CNNCTC_Data/ST_MJ') + generate_fix_shape_index_list('/home/workspace/mindspore_dataset/CNNCTC_Data/ST_MJ', combination, + '/home/workspace/mindspore_dataset/CNNCTC_Data/st_mj_fixed_length_index_list.pkl') + + print('Done') diff --git a/model_zoo/official/cv/cnnctc/src/util.py b/model_zoo/official/cv/cnnctc/src/util.py new file mode 100644 index 00000000000..ac3d98c68a5 --- /dev/null +++ b/model_zoo/official/cv/cnnctc/src/util.py @@ -0,0 +1,102 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""util file""" + +import numpy as np + +class AverageMeter(): + """Computes and stores the average and current value""" + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +class CTCLabelConverter(): + """ Convert between text-label and text-index """ + + def __init__(self, character): + # character (str): set of the possible characters. + dict_character = list(character) + + self.dict = {} + for i, char in enumerate(dict_character): + self.dict[char] = i + + self.character = dict_character + ['[blank]'] # dummy '[blank]' token for CTCLoss (index 0) + self.dict['[blank]'] = len(dict_character) + + def encode(self, text): + """convert text-label into text-index. + input: + text: text labels of each image. [batch_size] + + output: + text: concatenated text index for CTCLoss. + [sum(text_lengths)] = [text_index_0 + text_index_1 + ... + text_index_(n - 1)] + length: length of each text. [batch_size] + """ + length = [len(s) for s in text] + text = ''.join(text) + text = [self.dict[char] for char in text] + + return np.array(text), np.array(length) + + def decode(self, text_index, length): + """ convert text-index into text-label. 
""" + texts = [] + index = 0 + for l in length: + t = text_index[index:index + l] + + char_list = [] + for i in range(l): + # if t[i] != self.dict['[blank]'] and (not (i > 0 and t[i - 1] == t[i])): # removing repeated characters and blank. + if t[i] != self.dict['[blank]'] and ( + not (i > 0 and t[i - 1] == t[i])): # removing repeated characters and blank. + char_list.append(self.character[t[i]]) + text = ''.join(char_list) + + texts.append(text) + index += l + return texts + + def reverse_encode(self, text_index, length): + """ convert text-index into text-label. """ + texts = [] + index = 0 + for l in length: + t = text_index[index:index + l] + + char_list = [] + for i in range(l): + if t[i] != self.dict['[blank]']: # removing repeated characters and blank. + char_list.append(self.character[t[i]]) + text = ''.join(char_list) + + texts.append(text) + index += l + return texts diff --git a/model_zoo/official/cv/cnnctc/train.py b/model_zoo/official/cv/cnnctc/train.py new file mode 100644 index 00000000000..a85f484e411 --- /dev/null +++ b/model_zoo/official/cv/cnnctc/train.py @@ -0,0 +1,100 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""cnnctc train""" + +import argparse +import ast + +import mindspore +from mindspore import context +from mindspore.train.serialization import load_checkpoint, load_param_into_net +from mindspore.dataset import GeneratorDataset +from mindspore.train.callback import ModelCheckpoint, CheckpointConfig +from mindspore.train.model import Model +from mindspore.communication.management import init +from mindspore.common import set_seed + +from src.config import Config_CNNCTC +from src.callback import LossCallBack +from src.dataset import ST_MJ_Generator_batch_fixed_length, ST_MJ_Generator_batch_fixed_length_para +from src.cnn_ctc import CNNCTC_Model, ctc_loss, WithLossCell + +set_seed(1) + +context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False, + save_graphs_path=".", enable_auto_mixed_precision=False) + + +def dataset_creator(run_distribute): + if run_distribute: + st_dataset = ST_MJ_Generator_batch_fixed_length_para() + else: + st_dataset = ST_MJ_Generator_batch_fixed_length() + + ds = GeneratorDataset(st_dataset, + ['img', 'label_indices', 'text', 'sequence_length'], + num_parallel_workers=8) + + return ds + + +def train(args_opt, config): + if args_opt.run_distribute: + init() + context.set_auto_parallel_context(parallel_mode="data_parallel") + + ds = dataset_creator(args_opt.run_distribute) + + net = CNNCTC_Model(config.NUM_CLASS, config.HIDDEN_SIZE, config.FINAL_FEATURE_WIDTH) + net.set_train(True) + + if config.CKPT_PATH != '': + param_dict = load_checkpoint(config.CKPT_PATH) + load_param_into_net(net, param_dict) + print('parameters loaded!') + else: + print('train from scratch...') + + criterion = ctc_loss() + opt = mindspore.nn.RMSProp(params=net.trainable_params(), centered=True, 
diff --git a/model_zoo/official/cv/cnnctc/train.py b/model_zoo/official/cv/cnnctc/train.py
new file mode 100644
index 00000000000..a85f484e411
--- /dev/null
+++ b/model_zoo/official/cv/cnnctc/train.py
@@ -0,0 +1,100 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""cnnctc train"""
+
+import argparse
+import ast
+
+import mindspore
+from mindspore import context
+from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from mindspore.dataset import GeneratorDataset
+from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
+from mindspore.train.model import Model
+from mindspore.communication.management import init
+from mindspore.common import set_seed
+
+from src.config import Config_CNNCTC
+from src.callback import LossCallBack
+from src.dataset import ST_MJ_Generator_batch_fixed_length, ST_MJ_Generator_batch_fixed_length_para
+from src.cnn_ctc import CNNCTC_Model, ctc_loss, WithLossCell
+
+set_seed(1)
+
+context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", save_graphs=False,
+                    save_graphs_path=".", enable_auto_mixed_precision=False)
+
+
+def dataset_creator(run_distribute):
+    if run_distribute:
+        st_dataset = ST_MJ_Generator_batch_fixed_length_para()
+    else:
+        st_dataset = ST_MJ_Generator_batch_fixed_length()
+
+    ds = GeneratorDataset(st_dataset,
+                          ['img', 'label_indices', 'text', 'sequence_length'],
+                          num_parallel_workers=8)
+
+    return ds
+
+
+def train(args_opt, config):
+    if args_opt.run_distribute:
+        init()
+        context.set_auto_parallel_context(parallel_mode="data_parallel")
+
+    ds = dataset_creator(args_opt.run_distribute)
+
+    net = CNNCTC_Model(config.NUM_CLASS, config.HIDDEN_SIZE, config.FINAL_FEATURE_WIDTH)
+    net.set_train(True)
+
+    if config.CKPT_PATH != '':
+        param_dict = load_checkpoint(config.CKPT_PATH)
+        load_param_into_net(net, param_dict)
+        print('parameters loaded!')
+    else:
+        print('train from scratch...')
+
+    criterion = ctc_loss()
+    # LR_PARA is the learning rate documented for distributed runs; LR is for standalone runs
+    lr = config.LR_PARA if args_opt.run_distribute else config.LR
+    opt = mindspore.nn.RMSProp(params=net.trainable_params(), centered=True, learning_rate=lr,
+                               momentum=config.MOMENTUM, loss_scale=config.LOSS_SCALE)
+
+    net = WithLossCell(net, criterion)
+    loss_scale_manager = mindspore.train.loss_scale_manager.FixedLossScaleManager(config.LOSS_SCALE, False)
+    model = Model(net, optimizer=opt, loss_scale_manager=loss_scale_manager, amp_level="O2")
+
+    callback = LossCallBack()
+    config_ck = CheckpointConfig(save_checkpoint_steps=config.SAVE_CKPT_PER_N_STEP,
+                                 keep_checkpoint_max=config.KEEP_CKPT_MAX_NUM)
+    ckpoint_cb = ModelCheckpoint(prefix="CNNCTC", config=config_ck, directory=config.SAVE_PATH)
+
+    # only device 0 writes checkpoints; the other devices only report the loss
+    if args_opt.device_id == 0:
+        model.train(config.TRAIN_EPOCHS, ds, callbacks=[callback, ckpoint_cb], dataset_sink_mode=False)
+    else:
+        model.train(config.TRAIN_EPOCHS, ds, callbacks=[callback], dataset_sink_mode=False)
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='CNNCTC training')
+    parser.add_argument('--device_id', type=int, default=0, help="Device id, default is 0.")
+    parser.add_argument("--ckpt_path", type=str, default="", help="Pretrained checkpoint file path.")
+    parser.add_argument("--run_distribute", type=ast.literal_eval, default=False,
+                        help="Run distributed training, default is false.")
+    args_cfg = parser.parse_args()
+
+    cfg = Config_CNNCTC()
+    if args_cfg.ckpt_path != "":
+        cfg.CKPT_PATH = args_cfg.ckpt_path
+    train(args_cfg, cfg)
diff --git a/model_zoo/official/cv/faster_rcnn/src/config.py b/model_zoo/official/cv/faster_rcnn/src/config.py
index 24523793a46..28123399e85 100644
--- a/model_zoo/official/cv/faster_rcnn/src/config.py
+++ b/model_zoo/official/cv/faster_rcnn/src/config.py
@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-#" :===========================================================================
+# ===========================================================================
 """
 network config setting, will be used in train.py and eval.py
 """