From 03c57a1e8b72d0996b9429da92dde829c3f0c2ad Mon Sep 17 00:00:00 2001 From: gengdongjie Date: Mon, 29 Jun 2020 22:01:52 +0800 Subject: [PATCH] add warpctc to modelzoo --- mindspore/ops/_grad/grad_nn_ops.py | 2 +- .../tbe/basic_lstm_cell_c_state_grad.py | 4 +- model_zoo/warpctc/README.md | 137 ++++++++++++++++++ model_zoo/warpctc/eval.py | 65 +++++++++ model_zoo/warpctc/process_data.py | 71 +++++++++ .../warpctc/scripts/run_distribute_train.sh | 62 ++++++++ model_zoo/warpctc/scripts/run_eval.sh | 60 ++++++++ model_zoo/warpctc/scripts/run_process_data.sh | 20 +++ .../warpctc/scripts/run_standalone_train.sh | 54 +++++++ model_zoo/warpctc/src/config.py | 31 ++++ model_zoo/warpctc/src/dataset.py | 92 ++++++++++++ model_zoo/warpctc/src/loss.py | 49 +++++++ model_zoo/warpctc/src/lr_schedule.py | 36 +++++ model_zoo/warpctc/src/metric.py | 89 ++++++++++++ model_zoo/warpctc/src/warpctc.py | 90 ++++++++++++ model_zoo/warpctc/src/warpctc_for_train.py | 114 +++++++++++++++ model_zoo/warpctc/train.py | 84 +++++++++++ 17 files changed, 1057 insertions(+), 3 deletions(-) create mode 100644 model_zoo/warpctc/README.md create mode 100755 model_zoo/warpctc/eval.py create mode 100755 model_zoo/warpctc/process_data.py create mode 100755 model_zoo/warpctc/scripts/run_distribute_train.sh create mode 100755 model_zoo/warpctc/scripts/run_eval.sh create mode 100755 model_zoo/warpctc/scripts/run_process_data.sh create mode 100755 model_zoo/warpctc/scripts/run_standalone_train.sh create mode 100755 model_zoo/warpctc/src/config.py create mode 100755 model_zoo/warpctc/src/dataset.py create mode 100755 model_zoo/warpctc/src/loss.py create mode 100755 model_zoo/warpctc/src/lr_schedule.py create mode 100755 model_zoo/warpctc/src/metric.py create mode 100755 model_zoo/warpctc/src/warpctc.py create mode 100755 model_zoo/warpctc/src/warpctc_for_train.py create mode 100755 model_zoo/warpctc/train.py diff --git a/mindspore/ops/_grad/grad_nn_ops.py b/mindspore/ops/_grad/grad_nn_ops.py index 00b9e3051be..107de1768cc 100755 --- a/mindspore/ops/_grad/grad_nn_ops.py +++ b/mindspore/ops/_grad/grad_nn_ops.py @@ -716,7 +716,7 @@ def get_bprop_basic_lstm_cell(self): def bprop(x, h, c, w, b, out, dout): _, _, it, jt, ft, ot, tanhct = out dct, dht, _, _, _, _, _ = dout - dgate, dct_1 = basic_lstm_cell_cstate_grad(c, dht, dct, it, ft, jt, ot, tanhct) + dgate, dct_1 = basic_lstm_cell_cstate_grad(c, dht, dct, it, jt, ft, ot, tanhct) dxt, dht = basic_lstm_cell_input_grad(dgate, w) dw, db = basic_lstm_cell_weight_grad(F.depend(x, dxt), h, dgate) return dxt, dht, dct_1, dw, db diff --git a/mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad.py b/mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad.py index 440b1ce2c71..1e42c1d6fe2 100644 --- a/mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad.py +++ b/mindspore/ops/_op_impl/tbe/basic_lstm_cell_c_state_grad.py @@ -29,8 +29,8 @@ basic_lstm_cell_c_state_grad_op_info = TBERegOp("BasicLSTMCellCStateGrad") \ .input(1, "dht", False, "required", "all") \ .input(2, "dct", False, "required", "all") \ .input(3, "it", False, "required", "all") \ - .input(4, "ft", False, "required", "all") \ - .input(5, "jt", False, "required", "all") \ + .input(4, "jt", False, "required", "all") \ + .input(5, "ft", False, "required", "all") \ .input(6, "ot", False, "required", "all") \ .input(7, "tanhct", False, "required", "all") \ .output(0, "dgate", False, "required", "all") \ diff --git a/model_zoo/warpctc/README.md b/model_zoo/warpctc/README.md new file mode 100644 index 00000000000..cb941255bfb --- /dev/null +++ b/model_zoo/warpctc/README.md @@ -0,0 +1,137 @@ +# Warpctc Example + +## Description + +These is an example of training Warpctc with self-generated captcha image dataset in MindSpore. + +## Requirements + +- Install [MindSpore](https://www.mindspore.cn/install/en). + +- Generate captcha images. + +> The [captcha](https://github.com/lepture/captcha) library can be used to generate captcha images. You can generate the train and test dataset by yourself or just run the script `scripts/run_process_data.sh`. By default, the shell script will generate 10000 test images and 50000 train images separately. +> ``` +> $ cd scripts +> $ sh run_process_data.sh +> +> # after execution, you will find the dataset like the follows: +> . +> └─warpctc +> └─data +> ├─ train # train dataset +> └─ test # evaluate dataset +> ... + + +## Structure + +```shell +. +└──warpct + ├── README.md + ├── script + ├── run_distribute_train.sh # launch distributed training(8 pcs) + ├── run_eval.sh # launch evaluation + ├── run_process_data.sh # launch dataset generation + └── run_standalone_train.sh # launch standalone training(1 pcs) + ├── src + ├── config.py # parameter configuration + ├── dataset.py # data preprocessing + ├── loss.py # ctcloss definition + ├── lr_generator.py # generate learning rate for each step + ├── metric.py # accuracy metric for warpctc network + ├── warpctc.py # warpctc network definition + └── warpctc_for_train.py # warp network with grad, loss and gradient clip + ├── eval.py # eval net + ├── process_data.py # dataset generation script + └── train.py # train net +``` + + +## Parameter configuration + +Parameters for both training and evaluation can be set in config.py. + +``` +"max_captcha_digits": 4, # max number of digits in each +"captcha_width": 160, # width of captcha images +"captcha_height": 64, # height of capthca images +"batch_size": 64, # batch size of input tensor +"epoch_size": 30, # only valid for taining, which is always 1 for inference +"hidden_size": 512, # hidden size in LSTM layers +"learning_rate": 0.01, # initial learning rate +"momentum": 0.9 # momentum of SGD optimizer +"save_checkpoint": True, # whether save checkpoint or not +"save_checkpoint_steps": 98, # the step interval between two checkpoints. By default, the last checkpoint will be saved after the last step +"keep_checkpoint_max": 30, # only keep the last keep_checkpoint_max checkpoint +"save_checkpoint_path": "./", # path to save checkpoint +``` + +## Running the example + +### Train + +#### Usage + +``` +# distributed training +Usage: sh run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [DATASET_PATH] + +# standalone training +Usage: sh run_standalone_train.sh [DATASET_PATH] +``` + + +#### Launch + +``` +# distribute training example +sh run_distribute_train.sh rank_table.json ../data/train + +# standalone training example +sh run_standalone_train.sh ../data/train +``` + +> About rank_table.json, you can refer to the [distributed training tutorial](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html). + +#### Result + +Training result will be stored in folder `scripts`, whose name begins with "train" or "train_parallel". Under this, you can find checkpoint file together with result like the followings in log. + +``` +# distribute training result(8 pcs) +Epoch: [ 1/ 30], step: [ 98/ 98], loss: [0.5853/0.5853], time: [376813.7944] +Epoch: [ 2/ 30], step: [ 98/ 98], loss: [0.4007/0.4007], time: [75882.0951] +Epoch: [ 3/ 30], step: [ 98/ 98], loss: [0.0921/0.0921], time: [75150.9385] +Epoch: [ 4/ 30], step: [ 98/ 98], loss: [0.1472/0.1472], time: [75135.0193] +Epoch: [ 5/ 30], step: [ 98/ 98], loss: [0.0186/0.0186], time: [75199.5809] +... +``` + + +### Evaluation + +#### Usage + +``` +# evaluation +Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH] +``` + +#### Launch + +``` +# evaluation example +sh run_eval.sh ../data/test warpctc-30-98.ckpt +``` + +> checkpoint can be produced in training process. + +#### Result + +Evaluation result will be stored in the example path, whose folder name is "eval". Under this, you can find result like the followings in log. + +``` +result: {'WarpCTCAccuracy': 0.9901472929936306} +``` diff --git a/model_zoo/warpctc/eval.py b/model_zoo/warpctc/eval.py new file mode 100755 index 00000000000..df62c7c7551 --- /dev/null +++ b/model_zoo/warpctc/eval.py @@ -0,0 +1,65 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Warpctc evaluation""" +import os +import math as m +import random +import argparse +import numpy as np +from mindspore import context +from mindspore import dataset as de +from mindspore.train.model import Model +from mindspore.train.serialization import load_checkpoint, load_param_into_net + +from src.loss import CTCLoss +from src.config import config as cf +from src.dataset import create_dataset +from src.warpctc import StackedRNN +from src.metric import WarpCTCAccuracy + +random.seed(1) +np.random.seed(1) +de.config.set_seed(1) + +parser = argparse.ArgumentParser(description="Warpctc training") +parser.add_argument("--dataset_path", type=str, default=None, help="Dataset, default is None.") +parser.add_argument("--checkpoint_path", type=str, default=None, help="checkpoint file path, default is None") +args_opt = parser.parse_args() + +device_id = int(os.getenv('DEVICE_ID')) +context.set_context(mode=context.GRAPH_MODE, + device_target="Ascend", + save_graphs=False, + device_id=device_id) + +if __name__ == '__main__': + max_captcha_digits = cf.max_captcha_digits + input_size = m.ceil(cf.captcha_height / 64) * 64 * 3 + # create dataset + dataset = create_dataset(dataset_path=args_opt.dataset_path, repeat_num=1, batch_size=cf.batch_size) + step_size = dataset.get_dataset_size() + # define loss + loss = CTCLoss(max_sequence_length=cf.captcha_width, max_label_length=max_captcha_digits, batch_size=cf.batch_size) + # define net + net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size) + # load checkpoint + param_dict = load_checkpoint(args_opt.checkpoint_path) + load_param_into_net(net, param_dict) + net.set_train(False) + # define model + model = Model(net, loss_fn=loss, metrics={'WarpCTCAccuracy': WarpCTCAccuracy()}) + # start evaluation + res = model.eval(dataset) + print("result:", res, flush=True) diff --git a/model_zoo/warpctc/process_data.py b/model_zoo/warpctc/process_data.py new file mode 100755 index 00000000000..567ad109336 --- /dev/null +++ b/model_zoo/warpctc/process_data.py @@ -0,0 +1,71 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Generate train and test dataset""" +import os +import math as m +import random +from multiprocessing import Process +from captcha.image import ImageCaptcha + + +def _generate_captcha_per_process(path, total, start, end, img_width, img_height, max_digits): + captcha = ImageCaptcha(width=img_width, height=img_height) + filename_head = '{:0>' + str(len(str(total))) + '}-' + for i in range(start, end): + digits = '' + digits_length = random.randint(1, max_digits) + for _ in range(0, digits_length): + integer = random.randint(0, 9) + digits += str(integer) + captcha.write(digits, os.path.join(path, filename_head.format(i) + digits + '.png')) + + +def generate_captcha(name, img_num, img_width, img_height, max_digits, process_num=16): + """ + generate captcha images + + Args: + name(str): name of folder, under which captcha images are saved in + img_num(int): number of generated captcha images + img_width(int): width of generated captcha images + img_height(int): height of generated captcha images + max_digits(int): max number of digits in each captcha images. For each captcha images, number of digits is in + range [1,max_digits] + process_num(int): number of process to generate captcha images, default is 16 + """ + cur_script_path = os.path.dirname(os.path.realpath(__file__)) + path = os.path.join(cur_script_path, "data", name) + print("Generating dataset [{}] under {}...".format(name, path)) + if os.path.exists(path): + os.system("rm -rf {}".format(path)) + os.system("mkdir -p {}".format(path)) + img_num_per_thread = m.ceil(img_num / process_num) + + processes = [] + for i in range(process_num): + start = i * img_num_per_thread + end = start + img_num_per_thread if i != (process_num - 1) else img_num + p = Process(target=_generate_captcha_per_process, + args=(path, img_num, start, end, img_width, img_height, max_digits)) + p.start() + processes.append(p) + for p in processes: + p.join() + print("Generating dataset [{}] finished, total number is {}!".format(name, img_num)) + + +if __name__ == '__main__': + generate_captcha("test", img_num=10000, img_width=160, img_height=64, max_digits=4) + generate_captcha("train", img_num=50000, img_width=160, img_height=64, max_digits=4) diff --git a/model_zoo/warpctc/scripts/run_distribute_train.sh b/model_zoo/warpctc/scripts/run_distribute_train.sh new file mode 100755 index 00000000000..3cebf6d195f --- /dev/null +++ b/model_zoo/warpctc/scripts/run_distribute_train.sh @@ -0,0 +1,62 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# != 2 ]; then + echo "Usage: sh run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [DATASET_PATH]" + exit 1 +fi + +get_real_path() { + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +PATH1=$(get_real_path $1) +PATH2=$(get_real_path $2) + +if [ ! -f $PATH1 ]; then + echo "error: MINDSPORE_HCCL_CONFIG_PATH=$PATH1 is not a file" + exit 1 +fi + +if [ ! -d $PATH2 ]; then + echo "error: DATASET_PATH=$PATH2 is not a directory" + exit 1 +fi + +ulimit -u unlimited +export DEVICE_NUM=8 +export RANK_SIZE=8 +export MINDSPORE_HCCL_CONFIG_PATH=$PATH1 +export RANK_TABLE_FILE=$PATH1 + +for ((i = 0; i < ${DEVICE_NUM}; i++)); do + export DEVICE_ID=$i + export RANK_ID=$i + rm -rf ./train_parallel$i + mkdir ./train_parallel$i + cp ../*.py ./train_parallel$i + cp *.sh ./train_parallel$i + cp -r ../src ./train_parallel$i + cd ./train_parallel$i || exit + echo "start training for rank $RANK_ID, device $DEVICE_ID" + env >env.log + python train.py --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 &>log & + cd .. +done diff --git a/model_zoo/warpctc/scripts/run_eval.sh b/model_zoo/warpctc/scripts/run_eval.sh new file mode 100755 index 00000000000..659de6d72a3 --- /dev/null +++ b/model_zoo/warpctc/scripts/run_eval.sh @@ -0,0 +1,60 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# != 2 ]; then + echo "Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]" + exit 1 +fi + +get_real_path() { + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +PATH1=$(get_real_path $1) +PATH2=$(get_real_path $2) + +if [ ! -d $PATH1 ]; then + echo "error: DATASET_PATH=$PATH1 is not a directory" + exit 1 +fi + +if [ ! -f $PATH2 ]; then + echo "error: CHECKPOINT_PATH=$PATH2 is not a file" + exit 1 +fi + +ulimit -u unlimited +export DEVICE_NUM=1 +export DEVICE_ID=0 +export RANK_SIZE=$DEVICE_NUM +export RANK_ID=0 + +if [ -d "eval" ]; then + rm -rf ./eval +fi +mkdir ./eval +cp ../*.py ./eval +cp *.sh ./eval +cp -r ../src ./eval +cd ./eval || exit +env >env.log +echo "start evaluation for device $DEVICE_ID" +python eval.py --dataset_path=$PATH1 --checkpoint_path=$PATH2 &>log & +cd .. diff --git a/model_zoo/warpctc/scripts/run_process_data.sh b/model_zoo/warpctc/scripts/run_process_data.sh new file mode 100755 index 00000000000..56b89f1a72c --- /dev/null +++ b/model_zoo/warpctc/scripts/run_process_data.sh @@ -0,0 +1,20 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +CUR_PATH=$(dirname $PWD/$0) +cd $CUR_PATH/../ && + python process_data.py && + cd - || exit \ No newline at end of file diff --git a/model_zoo/warpctc/scripts/run_standalone_train.sh b/model_zoo/warpctc/scripts/run_standalone_train.sh new file mode 100755 index 00000000000..22a16ef4c82 --- /dev/null +++ b/model_zoo/warpctc/scripts/run_standalone_train.sh @@ -0,0 +1,54 @@ +#!/bin/bash +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ + +if [ $# != 1 ]; then + echo "Usage: sh run_standalone_train.sh [DATASET_PATH]" + exit 1 +fi + +get_real_path() { + if [ "${1:0:1}" == "/" ]; then + echo "$1" + else + echo "$(realpath -m $PWD/$1)" + fi +} + +PATH1=$(get_real_path $1) + +if [ ! -d $PATH1 ]; then + echo "error: DATASET_PATH=$PATH1 is not a directory" + exit 1 +fi + +ulimit -u unlimited +export DEVICE_NUM=1 +export DEVICE_ID=0 +export RANK_ID=0 +export RANK_SIZE=1 + +if [ -d "train" ]; then + rm -rf ./train +fi +mkdir ./train +cp ../*.py ./train +cp *.sh ./train +cp -r ../src ./train +cd ./train || exit +echo "start training for device $DEVICE_ID" +env >env.log +python train.py --dataset=$PATH1 &>log & +cd .. diff --git a/model_zoo/warpctc/src/config.py b/model_zoo/warpctc/src/config.py new file mode 100755 index 00000000000..ed9c2968de0 --- /dev/null +++ b/model_zoo/warpctc/src/config.py @@ -0,0 +1,31 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Network parameters.""" +from easydict import EasyDict + +config = EasyDict({ + "max_captcha_digits": 4, + "captcha_width": 160, + "captcha_height": 64, + "batch_size": 64, + "epoch_size": 30, + "hidden_size": 512, + "learning_rate": 0.01, + "momentum": 0.9, + "save_checkpoint": True, + "save_checkpoint_steps": 98, + "keep_checkpoint_max": 30, + "save_checkpoint_path": "./", +}) diff --git a/model_zoo/warpctc/src/dataset.py b/model_zoo/warpctc/src/dataset.py new file mode 100755 index 00000000000..76e592b906f --- /dev/null +++ b/model_zoo/warpctc/src/dataset.py @@ -0,0 +1,92 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Dataset preprocessing.""" +import os +import math as m +import numpy as np +import mindspore.common.dtype as mstype +import mindspore.dataset.engine as de +import mindspore.dataset.transforms.c_transforms as c +import mindspore.dataset.transforms.vision.c_transforms as vc +from PIL import Image +from src.config import config as cf + + +class _CaptchaDataset(): + """ + create train or evaluation dataset for warpctc + + Args: + img_root_dir(str): root path of images + max_captcha_digits(int): max number of digits in images. + blank(int): value reserved for blank label, default is 10. When parsing label from image file names, if label + length is less than max_captcha_digits, the remaining labels are padding with blank. + """ + + def __init__(self, img_root_dir, max_captcha_digits, blank=10): + if not os.path.exists(img_root_dir): + raise RuntimeError("the input image dir {} is invalid!".format(img_root_dir)) + self.img_root_dir = img_root_dir + self.img_names = [i for i in os.listdir(img_root_dir) if i.endswith('.png')] + self.max_captcha_digits = max_captcha_digits + self.blank = blank + + def __len__(self): + return len(self.img_names) + + def __getitem__(self, item): + img_name = self.img_names[item] + im = Image.open(os.path.join(self.img_root_dir, img_name)) + r, g, b = im.split() + im = Image.merge("RGB", (b, g, r)) + image = np.array(im) + label_str = os.path.splitext(img_name)[0] + label_str = label_str[label_str.find('-') + 1:] + label = [int(i) for i in label_str] + label.extend([int(self.blank)] * (self.max_captcha_digits - len(label))) + label = np.array(label) + return image, label + + +def create_dataset(dataset_path, repeat_num=1, batch_size=1): + """ + create train or evaluation dataset for warpctc + + Args: + dataset_path(int): dataset path + repeat_num(int): dataset repetition num, default is 1 + batch_size(int): batch size of generated dataset, default is 1 + """ + rank_size = int(os.environ.get("RANK_SIZE")) if os.environ.get("RANK_SIZE") else 1 + rank_id = int(os.environ.get("RANK_ID")) if os.environ.get("RANK_ID") else 0 + + dataset = _CaptchaDataset(dataset_path, cf.max_captcha_digits) + ds = de.GeneratorDataset(dataset, ["image", "label"], shuffle=True, num_shards=rank_size, shard_id=rank_id) + ds.set_dataset_size(m.ceil(len(dataset) / rank_size)) + image_trans = [ + vc.Rescale(1.0 / 255.0, 0.0), + vc.Normalize([0.9010, 0.9049, 0.9025], std=[0.1521, 0.1347, 0.1458]), + vc.Resize((m.ceil(cf.captcha_height / 16) * 16, cf.captcha_width)), + vc.HWC2CHW() + ] + label_trans = [ + c.TypeCast(mstype.int32) + ] + ds = ds.map(input_columns=["image"], num_parallel_workers=8, operations=image_trans) + ds = ds.map(input_columns=["label"], num_parallel_workers=8, operations=label_trans) + + ds = ds.batch(batch_size) + ds = ds.repeat(repeat_num) + return ds diff --git a/model_zoo/warpctc/src/loss.py b/model_zoo/warpctc/src/loss.py new file mode 100755 index 00000000000..8ea4c20e94a --- /dev/null +++ b/model_zoo/warpctc/src/loss.py @@ -0,0 +1,49 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""CTC Loss.""" +import numpy as np +from mindspore.nn.loss.loss import _Loss +from mindspore import Tensor, Parameter +from mindspore.common import dtype as mstype +from mindspore.ops import operations as P + + +class CTCLoss(_Loss): + """ + CTCLoss definition + + Args: + max_sequence_length(int): max number of sequence length. For captcha images, the value is equal to image + width + max_label_length(int): max number of label length for each input. + batch_size(int): batch size of input logits + """ + + def __init__(self, max_sequence_length, max_label_length, batch_size): + super(CTCLoss, self).__init__() + self.sequence_length = Parameter(Tensor(np.array([max_sequence_length] * batch_size), mstype.int32), + name="sequence_length") + labels_indices = [] + for i in range(batch_size): + for j in range(max_label_length): + labels_indices.append([i, j]) + self.labels_indices = Parameter(Tensor(np.array(labels_indices), mstype.int64), name="labels_indices") + self.reshape = P.Reshape() + self.ctc_loss = P.CTCLoss(ctc_merge_repeated=True) + + def construct(self, logit, label): + labels_values = self.reshape(label, (-1,)) + loss, _ = self.ctc_loss(logit, self.labels_indices, labels_values, self.sequence_length) + return loss diff --git a/model_zoo/warpctc/src/lr_schedule.py b/model_zoo/warpctc/src/lr_schedule.py new file mode 100755 index 00000000000..a0ae6c886ad --- /dev/null +++ b/model_zoo/warpctc/src/lr_schedule.py @@ -0,0 +1,36 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Learning rate generator.""" + + +def get_lr(epoch_size, step_size, lr_init): + """ + generate learning rate for each step, which decays in every 10 epoch + + Args: + epoch_size(int): total epoch number + step_size(int): total step number in each step + lr_init(int): initial learning rate + + Returns: + List, learning rate array + """ + lr = lr_init + lrs = [] + for i in range(1, epoch_size + 1): + if i % 10 == 0: + lr *= 0.1 + lrs.extend([lr for _ in range(step_size)]) + return lrs diff --git a/model_zoo/warpctc/src/metric.py b/model_zoo/warpctc/src/metric.py new file mode 100755 index 00000000000..d1060d0781b --- /dev/null +++ b/model_zoo/warpctc/src/metric.py @@ -0,0 +1,89 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Metric for accuracy evaluation.""" +from mindspore import nn + +BLANK_LABLE = 10 + + +class WarpCTCAccuracy(nn.Metric): + """ + Define accuracy metric for warpctc network. + """ + + def __init__(self): + super(WarpCTCAccuracy).__init__() + self._correct_num = 0 + self._total_num = 0 + self._count = 0 + + def clear(self): + self._correct_num = 0 + self._total_num = 0 + + def update(self, *inputs): + if len(inputs) != 2: + raise ValueError('WarpCTCAccuracy need 2 inputs (y_pred, y), but got {}'.format(len(inputs))) + + y_pred = self._convert_data(inputs[0]) + y = self._convert_data(inputs[1]) + + self._count += 1 + + pred_lbls = self._get_prediction(y_pred) + + for b_idx, target in enumerate(y): + if self._is_eq(pred_lbls[b_idx], target): + self._correct_num += 1 + self._total_num += 1 + + def eval(self): + if self._total_num == 0: + raise RuntimeError('Accuary can not be calculated, because the number of samples is 0.') + return self._correct_num / self._total_num + + @staticmethod + def _is_eq(pred_lbl, target): + """ + check whether predict label is equal to target label + """ + target = target.tolist() + pred_diff = len(target) - len(pred_lbl) + if pred_diff > 0: + # padding by BLANK_LABLE + pred_lbl.extend([BLANK_LABLE] * pred_diff) + return pred_lbl == target + + @staticmethod + def _get_prediction(y_pred): + """ + parse predict result to labels + """ + seq_len, batch_size, _ = y_pred.shape + indices = y_pred.argmax(axis=2) + + lens = [seq_len] * batch_size + pred_lbls = [] + for i in range(batch_size): + idx = indices[:, i] + last_idx = BLANK_LABLE + pred_lbl = [] + for j in range(lens[i]): + cur_idx = idx[j] + if cur_idx not in [last_idx, BLANK_LABLE]: + pred_lbl.append(cur_idx) + last_idx = cur_idx + pred_lbls.append(pred_lbl) + return pred_lbls diff --git a/model_zoo/warpctc/src/warpctc.py b/model_zoo/warpctc/src/warpctc.py new file mode 100755 index 00000000000..9669fc4bfd5 --- /dev/null +++ b/model_zoo/warpctc/src/warpctc.py @@ -0,0 +1,90 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Warpctc network definition.""" + +import numpy as np +import mindspore.nn as nn +from mindspore import Tensor, Parameter +from mindspore.common import dtype as mstype +from mindspore.ops import operations as P +from mindspore.ops import functional as F + + +class StackedRNN(nn.Cell): + """ + Define a stacked RNN network which contains two LSTM layers and one full-connect layer. + + Args: + input_size(int): Size of time sequence. Usually, the input_size is equal to three times of image height for + captcha images. + batch_size(int): batch size of input data, default is 64 + hidden_size(int): the hidden size in LSTM layers, default is 512 + """ + def __init__(self, input_size, batch_size=64, hidden_size=512): + super(StackedRNN, self).__init__() + self.batch_size = batch_size + self.input_size = input_size + self.num_classes = 11 + self.reshape = P.Reshape() + self.cast = P.Cast() + k = (1 / hidden_size) ** 0.5 + self.h1 = Tensor(np.zeros(shape=(batch_size, hidden_size)).astype(np.float16)) + self.c1 = Tensor(np.zeros(shape=(batch_size, hidden_size)).astype(np.float16)) + self.w1 = Parameter(np.random.uniform(-k, k, (4 * hidden_size, input_size + hidden_size, 1, 1)) + .astype(np.float16), name="w1") + self.w2 = Parameter(np.random.uniform(-k, k, (4 * hidden_size, hidden_size + hidden_size, 1, 1)) + .astype(np.float16), name="w2") + self.b1 = Parameter(np.random.uniform(-k, k, (4 * hidden_size, 1, 1, 1)).astype(np.float16), name="b1") + self.b2 = Parameter(np.random.uniform(-k, k, (4 * hidden_size, 1, 1, 1)).astype(np.float16), name="b2") + + self.h2 = Tensor(np.zeros(shape=(batch_size, hidden_size)).astype(np.float16)) + self.c2 = Tensor(np.zeros(shape=(batch_size, hidden_size)).astype(np.float16)) + + self.basic_lstm_cell = P.BasicLSTMCell(keep_prob=1.0, forget_bias=0.0, state_is_tuple=True, activation="tanh") + + self.fc_weight = np.random.random((self.num_classes, hidden_size)).astype(np.float32) + self.fc_bias = np.random.random((self.num_classes)).astype(np.float32) + + self.fc = nn.Dense(in_channels=hidden_size, out_channels=self.num_classes, weight_init=Tensor(self.fc_weight), + bias_init=Tensor(self.fc_bias)) + + self.fc.to_float(mstype.float32) + self.expand_dims = P.ExpandDims() + self.concat = P.Concat() + self.transpose = P.Transpose() + + def construct(self, x): + x = self.cast(x, mstype.float16) + x = self.transpose(x, (3, 0, 2, 1)) + x = self.reshape(x, (-1, self.batch_size, self.input_size)) + h1 = self.h1 + c1 = self.c1 + h2 = self.h2 + c2 = self.c2 + + c1, h1, _, _, _, _, _ = self.basic_lstm_cell(x[0, :, :], h1, c1, self.w1, self.b1) + c2, h2, _, _, _, _, _ = self.basic_lstm_cell(h1, h2, c2, self.w2, self.b2) + + h2_after_fc = self.fc(h2) + output = self.expand_dims(h2_after_fc, 0) + for i in range(1, F.shape(x)[0]): + c1, h1, _, _, _, _, _ = self.basic_lstm_cell(x[i, :, :], h1, c1, self.w1, self.b1) + c2, h2, _, _, _, _, _ = self.basic_lstm_cell(h1, h2, c2, self.w2, self.b2) + + h2_after_fc = self.fc(h2) + h2_after_fc = self.expand_dims(h2_after_fc, 0) + output = self.concat((output, h2_after_fc)) + + return output diff --git a/model_zoo/warpctc/src/warpctc_for_train.py b/model_zoo/warpctc/src/warpctc_for_train.py new file mode 100755 index 00000000000..d847f47c629 --- /dev/null +++ b/model_zoo/warpctc/src/warpctc_for_train.py @@ -0,0 +1,114 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Automatic differentiation with grad clip.""" +from mindspore.parallel._utils import (_get_device_num, _get_mirror_mean, + _get_parallel_mode) +from mindspore.train.parallel_utils import ParallelMode +from mindspore.common import dtype as mstype +from mindspore.ops import composite as C +from mindspore.ops import functional as F +from mindspore.ops import operations as P +from mindspore.nn.cell import Cell +from mindspore.nn.wrap.grad_reducer import DistributedGradReducer +import mindspore.nn as nn +from mindspore.common.tensor import Tensor +import numpy as np + +compute_norm = C.MultitypeFuncGraph("compute_norm") + + +@compute_norm.register("Tensor") +def _compute_norm(grad): + norm = nn.Norm() + norm = norm(F.cast(grad, mstype.float32)) + ret = F.expand_dims(F.cast(norm, mstype.float32), 0) + return ret + + +grad_div = C.MultitypeFuncGraph("grad_div") + + +@grad_div.register("Tensor", "Tensor") +def _grad_div(val, grad): + div = P.Div() + mul = P.Mul() + grad = mul(grad, 10.0) + ret = div(grad, val) + return ret + + +class TrainOneStepCellWithGradClip(Cell): + """ + Network training package class. + + Wraps the network with an optimizer. The resulting Cell be trained with input data and label. + Backward graph with grad clip will be created in the construct function to do parameter updating. + Different parallel modes are available to run the training. + + Args: + network (Cell): The training network. + optimizer (Cell): Optimizer for updating the weights. + sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0. + + Inputs: + - data (Tensor) - Tensor of shape :(N, ...). + - label (Tensor) - Tensor of shape :(N, ...). + + Outputs: + Tensor, a scalar Tensor with shape :math:`()`. + """ + + def __init__(self, network, optimizer, sens=1.0): + super(TrainOneStepCellWithGradClip, self).__init__(auto_prefix=False) + self.network = network + self.network.set_grad() + self.network.add_flags(defer_inline=True) + self.weights = optimizer.parameters + self.optimizer = optimizer + self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True) + self.sens = sens + self.reducer_flag = False + self.grad_reducer = None + self.hyper_map = C.HyperMap() + self.greater = P.Greater() + self.select = P.Select() + self.norm = nn.Norm(keep_dims=True) + self.dtype = P.DType() + self.cast = P.Cast() + self.concat = P.Concat(axis=0) + self.ten = Tensor(np.array([10.0]).astype(np.float32)) + parallel_mode = _get_parallel_mode() + if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL): + self.reducer_flag = True + if self.reducer_flag: + mean = _get_mirror_mean() + degree = _get_device_num() + self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree) + + def construct(self, data, label): + weights = self.weights + loss = self.network(data, label) + sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens) + grads = self.grad(self.network, weights)(data, label, sens) + norm = self.hyper_map(F.partial(compute_norm), grads) + norm = self.concat(norm) + norm = self.norm(norm) + cond = self.greater(norm, self.cast(self.ten, self.dtype(norm))) + clip_val = self.select(cond, norm, self.cast(self.ten, self.dtype(norm))) + grads = self.hyper_map(F.partial(grad_div, clip_val), grads) + if self.reducer_flag: + # apply grad reducer on grads + grads = self.grad_reducer(grads) + return F.depend(loss, self.optimizer(grads)) diff --git a/model_zoo/warpctc/train.py b/model_zoo/warpctc/train.py new file mode 100755 index 00000000000..651d2a73a4d --- /dev/null +++ b/model_zoo/warpctc/train.py @@ -0,0 +1,84 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""Warpctc training""" +import os +import math as m +import random +import argparse +import numpy as np +import mindspore.nn as nn +from mindspore import context +from mindspore import dataset as de +from mindspore.train.model import Model, ParallelMode +from mindspore.nn.wrap import WithLossCell +from mindspore.train.callback import TimeMonitor, LossMonitor, CheckpointConfig, ModelCheckpoint +from mindspore.communication.management import init + +from src.loss import CTCLoss +from src.config import config as cf +from src.dataset import create_dataset +from src.warpctc import StackedRNN +from src.warpctc_for_train import TrainOneStepCellWithGradClip +from src.lr_schedule import get_lr + +random.seed(1) +np.random.seed(1) +de.config.set_seed(1) + +parser = argparse.ArgumentParser(description="Warpctc training") +parser.add_argument("--run_distribute", type=bool, default=False, help="Run distribute, default is false.") +parser.add_argument('--device_num', type=int, default=1, help='Device num, default is 1.') +parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path, default is None') +args_opt = parser.parse_args() + +device_id = int(os.getenv('DEVICE_ID')) +context.set_context(mode=context.GRAPH_MODE, + device_target="Ascend", + save_graphs=False, + device_id=device_id) + +if __name__ == '__main__': + if args_opt.run_distribute: + context.reset_auto_parallel_context() + context.set_auto_parallel_context(device_num=args_opt.device_num, + parallel_mode=ParallelMode.DATA_PARALLEL, + mirror_mean=True) + init() + max_captcha_digits = cf.max_captcha_digits + input_size = m.ceil(cf.captcha_height / 64) * 64 * 3 + # create dataset + dataset = create_dataset(dataset_path=args_opt.dataset_path, repeat_num=cf.epoch_size, batch_size=cf.batch_size) + step_size = dataset.get_dataset_size() + # define lr + lr_init = cf.learning_rate if not args_opt.run_distribute else cf.learning_rate * args_opt.device_num + lr = get_lr(cf.epoch_size, step_size, lr_init) + # define loss + loss = CTCLoss(max_sequence_length=cf.captcha_width, max_label_length=max_captcha_digits, batch_size=cf.batch_size) + # define net + net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size) + # define opt + opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=cf.momentum) + net = WithLossCell(net, loss) + net = TrainOneStepCellWithGradClip(net, opt).set_train() + # define model + model = Model(net) + # define callbacks + callbacks = [LossMonitor(), TimeMonitor(data_size=step_size)] + if cf.save_checkpoint: + config_ck = CheckpointConfig(save_checkpoint_steps=cf.save_checkpoint_steps, + keep_checkpoint_max=cf.keep_checkpoint_max) + ckpt_cb = ModelCheckpoint(prefix="waptctc", directory=cf.save_checkpoint_path, config=config_ck) + callbacks.append(ckpt_cb) + model.train(cf.epoch_size, dataset, callbacks=callbacks)