add warpctc to modelzoo

This commit is contained in:
gengdongjie 2020-06-29 22:01:52 +08:00
parent cf4d317d3e
commit 03c57a1e8b
17 changed files with 1057 additions and 3 deletions

@@ -716,7 +716,7 @@ def get_bprop_basic_lstm_cell(self):
def bprop(x, h, c, w, b, out, dout):
_, _, it, jt, ft, ot, tanhct = out
dct, dht, _, _, _, _, _ = dout
-dgate, dct_1 = basic_lstm_cell_cstate_grad(c, dht, dct, it, ft, jt, ot, tanhct)
+dgate, dct_1 = basic_lstm_cell_cstate_grad(c, dht, dct, it, jt, ft, ot, tanhct)
dxt, dht = basic_lstm_cell_input_grad(dgate, w)
dw, db = basic_lstm_cell_weight_grad(F.depend(x, dxt), h, dgate)
return dxt, dht, dct_1, dw, db

@@ -29,8 +29,8 @@ basic_lstm_cell_c_state_grad_op_info = TBERegOp("BasicLSTMCellCStateGrad") \
.input(1, "dht", False, "required", "all") \
.input(2, "dct", False, "required", "all") \
.input(3, "it", False, "required", "all") \
.input(4, "ft", False, "required", "all") \
.input(5, "jt", False, "required", "all") \
.input(4, "jt", False, "required", "all") \
.input(5, "ft", False, "required", "all") \
.input(6, "ot", False, "required", "all") \
.input(7, "tanhct", False, "required", "all") \
.output(0, "dgate", False, "required", "all") \

model_zoo/warpctc/README.md (new file, 137 lines)
@@ -0,0 +1,137 @@
# Warpctc Example
## Description
This is an example of training Warpctc with a self-generated captcha image dataset in MindSpore.
## Requirements
- Install [MindSpore](https://www.mindspore.cn/install/en).
- Generate captcha images.
> The [captcha](https://github.com/lepture/captcha) library can be used to generate captcha images. You can generate the train and test datasets yourself, or just run the script `scripts/run_process_data.sh`. By default, the shell script generates 10000 test images and 50000 train images.
> ```
> $ cd scripts
> $ sh run_process_data.sh
>
> # after execution, you will find the dataset like the follows:
> .
> └─warpctc
> └─data
> ├─ train # train dataset
> └─ test # evaluate dataset
> ...
> ```
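If you only want to verify that the `captcha` dependency is installed, the minimal sketch below (not part of this commit; the label and file name are illustrative) renders one image at the size this example uses:

```python
from captcha.image import ImageCaptcha

# 160x64 matches captcha_width/captcha_height in src/config.py
captcha = ImageCaptcha(width=160, height=64)
# process_data.py encodes the digit label in the file name the same way
captcha.write('0123', '0123.png')
```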
## Structure
```shell
.
└── warpctc
  ├── README.md
  ├── scripts
    ├── run_distribute_train.sh    # launch distributed training (8 pcs)
    ├── run_eval.sh                # launch evaluation
    ├── run_process_data.sh        # launch dataset generation
    └── run_standalone_train.sh    # launch standalone training (1 pcs)
  ├── src
    ├── config.py                  # parameter configuration
    ├── dataset.py                 # data preprocessing
    ├── loss.py                    # ctcloss definition
    ├── lr_schedule.py             # generate learning rate for each step
    ├── metric.py                  # accuracy metric for warpctc network
    ├── warpctc.py                 # warpctc network definition
    └── warpctc_for_train.py       # warpctc network with grad, loss and gradient clip
  ├── eval.py                      # eval net
  ├── process_data.py              # dataset generation script
  └── train.py                     # train net
```
## Parameter configuration
Parameters for both training and evaluation can be set in `src/config.py`.
```
"max_captcha_digits": 4, # max number of digits in each
"captcha_width": 160, # width of captcha images
"captcha_height": 64, # height of capthca images
"batch_size": 64, # batch size of input tensor
"epoch_size": 30, # only valid for taining, which is always 1 for inference
"hidden_size": 512, # hidden size in LSTM layers
"learning_rate": 0.01, # initial learning rate
"momentum": 0.9 # momentum of SGD optimizer
"save_checkpoint": True, # whether save checkpoint or not
"save_checkpoint_steps": 98, # the step interval between two checkpoints. By default, the last checkpoint will be saved after the last step
"keep_checkpoint_max": 30, # only keep the last keep_checkpoint_max checkpoint
"save_checkpoint_path": "./", # path to save checkpoint
```
## Running the example
### Train
#### Usage
```
# distributed training
Usage: sh run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [DATASET_PATH]
# standalone training
Usage: sh run_standalone_train.sh [DATASET_PATH]
```
#### Launch
```
# distribute training example
sh run_distribute_train.sh rank_table.json ../data/train
# standalone training example
sh run_standalone_train.sh ../data/train
```
> For details about rank_table.json, refer to the [distributed training tutorial](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html).
#### Result
Training results will be stored in the `scripts` folder, in subfolders whose names begin with "train" or "train_parallel". There you can find checkpoint files together with log output like the following.
```
# distribute training result(8 pcs)
Epoch: [ 1/ 30], step: [ 98/ 98], loss: [0.5853/0.5853], time: [376813.7944]
Epoch: [ 2/ 30], step: [ 98/ 98], loss: [0.4007/0.4007], time: [75882.0951]
Epoch: [ 3/ 30], step: [ 98/ 98], loss: [0.0921/0.0921], time: [75150.9385]
Epoch: [ 4/ 30], step: [ 98/ 98], loss: [0.1472/0.1472], time: [75135.0193]
Epoch: [ 5/ 30], step: [ 98/ 98], loss: [0.0186/0.0186], time: [75199.5809]
...
```
### Evaluation
#### Usage
```
# evaluation
Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]
```
#### Launch
```
# evaluation example
sh run_eval.sh ../data/test warpctc-30-98.ckpt
```
> The checkpoint file is produced during training.
#### Result
The evaluation result will be stored in the example path, in a folder named "eval". There you can find the result in the log, like the following.
```
result: {'WarpCTCAccuracy': 0.9901472929936306}
```

model_zoo/warpctc/eval.py (new executable file, 65 lines)
@@ -0,0 +1,65 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Warpctc evaluation"""
import os
import math as m
import random
import argparse
import numpy as np
from mindspore import context
from mindspore import dataset as de
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.loss import CTCLoss
from src.config import config as cf
from src.dataset import create_dataset
from src.warpctc import StackedRNN
from src.metric import WarpCTCAccuracy
random.seed(1)
np.random.seed(1)
de.config.set_seed(1)
parser = argparse.ArgumentParser(description="Warpctc evaluation")
parser.add_argument("--dataset_path", type=str, default=None, help="Dataset path, default is None.")
parser.add_argument("--checkpoint_path", type=str, default=None, help="checkpoint file path, default is None")
args_opt = parser.parse_args()
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE,
device_target="Ascend",
save_graphs=False,
device_id=device_id)
if __name__ == '__main__':
max_captcha_digits = cf.max_captcha_digits
input_size = m.ceil(cf.captcha_height / 64) * 64 * 3
# create dataset
dataset = create_dataset(dataset_path=args_opt.dataset_path, repeat_num=1, batch_size=cf.batch_size)
step_size = dataset.get_dataset_size()
# define loss
loss = CTCLoss(max_sequence_length=cf.captcha_width, max_label_length=max_captcha_digits, batch_size=cf.batch_size)
# define net
net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
# load checkpoint
param_dict = load_checkpoint(args_opt.checkpoint_path)
load_param_into_net(net, param_dict)
net.set_train(False)
# define model
model = Model(net, loss_fn=loss, metrics={'WarpCTCAccuracy': WarpCTCAccuracy()})
# start evaluation
res = model.eval(dataset)
print("result:", res, flush=True)

model_zoo/warpctc/process_data.py (new file, 71 lines)
@@ -0,0 +1,71 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Generate train and test dataset"""
import os
import math as m
import random
from multiprocessing import Process
from captcha.image import ImageCaptcha
def _generate_captcha_per_process(path, total, start, end, img_width, img_height, max_digits):
captcha = ImageCaptcha(width=img_width, height=img_height)
filename_head = '{:0>' + str(len(str(total))) + '}-'
for i in range(start, end):
digits = ''
digits_length = random.randint(1, max_digits)
for _ in range(0, digits_length):
integer = random.randint(0, 9)
digits += str(integer)
captcha.write(digits, os.path.join(path, filename_head.format(i) + digits + '.png'))
def generate_captcha(name, img_num, img_width, img_height, max_digits, process_num=16):
"""
generate captcha images
Args:
name(str): name of the folder under which captcha images are saved
img_num(int): number of generated captcha images
img_width(int): width of generated captcha images
img_height(int): height of generated captcha images
max_digits(int): max number of digits in each captcha image. For each image, the number of digits is in the
range [1, max_digits]
process_num(int): number of processes used to generate captcha images, default is 16
"""
cur_script_path = os.path.dirname(os.path.realpath(__file__))
path = os.path.join(cur_script_path, "data", name)
print("Generating dataset [{}] under {}...".format(name, path))
if os.path.exists(path):
os.system("rm -rf {}".format(path))
os.system("mkdir -p {}".format(path))
img_num_per_thread = m.ceil(img_num / process_num)
processes = []
for i in range(process_num):
start = i * img_num_per_thread
end = min(start + img_num_per_thread, img_num)  # clamp so the per-process range never overshoots img_num
p = Process(target=_generate_captcha_per_process,
args=(path, img_num, start, end, img_width, img_height, max_digits))
p.start()
processes.append(p)
for p in processes:
p.join()
print("Generating dataset [{}] finished, total number is {}!".format(name, img_num))
if __name__ == '__main__':
generate_captcha("test", img_num=10000, img_width=160, img_height=64, max_digits=4)
generate_captcha("train", img_num=50000, img_width=160, img_height=64, max_digits=4)

model_zoo/warpctc/scripts/run_distribute_train.sh (new file, 62 lines)
@@ -0,0 +1,62 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]; then
echo "Usage: sh run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [DATASET_PATH]"
exit 1
fi
get_real_path() {
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)
if [ ! -f $PATH1 ]; then
echo "error: MINDSPORE_HCCL_CONFIG_PATH=$PATH1 is not a file"
exit 1
fi
if [ ! -d $PATH2 ]; then
echo "error: DATASET_PATH=$PATH2 is not a directory"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
export MINDSPORE_HCCL_CONFIG_PATH=$PATH1
export RANK_TABLE_FILE=$PATH1
for ((i = 0; i < ${DEVICE_NUM}; i++)); do
export DEVICE_ID=$i
export RANK_ID=$i
rm -rf ./train_parallel$i
mkdir ./train_parallel$i
cp ../*.py ./train_parallel$i
cp *.sh ./train_parallel$i
cp -r ../src ./train_parallel$i
cd ./train_parallel$i || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID"
env >env.log
python train.py --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 &>log &
cd ..
done

model_zoo/warpctc/scripts/run_eval.sh (new file, 60 lines)
@@ -0,0 +1,60 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]; then
echo "Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]"
exit 1
fi
get_real_path() {
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)
if [ ! -d $PATH1 ]; then
echo "error: DATASET_PATH=$PATH1 is not a directory"
exit 1
fi
if [ ! -f $PATH2 ]; then
echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0
if [ -d "eval" ]; then
rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp *.sh ./eval
cp -r ../src ./eval
cd ./eval || exit
env >env.log
echo "start evaluation for device $DEVICE_ID"
python eval.py --dataset_path=$PATH1 --checkpoint_path=$PATH2 &>log &
cd ..

model_zoo/warpctc/scripts/run_process_data.sh (new file, 20 lines)
@@ -0,0 +1,20 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
CUR_PATH=$(dirname $PWD/$0)
cd $CUR_PATH/../ &&
python process_data.py &&
cd - || exit

model_zoo/warpctc/scripts/run_standalone_train.sh (new file, 54 lines)
@@ -0,0 +1,54 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 1 ]; then
echo "Usage: sh run_standalone_train.sh [DATASET_PATH]"
exit 1
fi
get_real_path() {
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
PATH1=$(get_real_path $1)
if [ ! -d $PATH1 ]; then
echo "error: DATASET_PATH=$PATH1 is not a directory"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1
if [ -d "train" ]; then
rm -rf ./train
fi
mkdir ./train
cp ../*.py ./train
cp *.sh ./train
cp -r ../src ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env >env.log
python train.py --dataset_path=$PATH1 &>log &
cd ..

model_zoo/warpctc/src/config.py (new executable file, 31 lines)
@@ -0,0 +1,31 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Network parameters."""
from easydict import EasyDict
config = EasyDict({
"max_captcha_digits": 4,
"captcha_width": 160,
"captcha_height": 64,
"batch_size": 64,
"epoch_size": 30,
"hidden_size": 512,
"learning_rate": 0.01,
"momentum": 0.9,
"save_checkpoint": True,
"save_checkpoint_steps": 98,
"keep_checkpoint_max": 30,
"save_checkpoint_path": "./",
})

model_zoo/warpctc/src/dataset.py (new file, 92 lines)
@@ -0,0 +1,92 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Dataset preprocessing."""
import os
import math as m
import numpy as np
import mindspore.common.dtype as mstype
import mindspore.dataset.engine as de
import mindspore.dataset.transforms.c_transforms as c
import mindspore.dataset.transforms.vision.c_transforms as vc
from PIL import Image
from src.config import config as cf
class _CaptchaDataset():
"""
Captcha image dataset source for warpctc, which parses the label from each image file name.
Args:
img_root_dir(str): root path of images
max_captcha_digits(int): max number of digits in images.
blank(int): value reserved for the blank label, default is 10. When parsing labels from image file names, if the
label length is less than max_captcha_digits, the remaining positions are padded with blank.
"""
def __init__(self, img_root_dir, max_captcha_digits, blank=10):
if not os.path.exists(img_root_dir):
raise RuntimeError("the input image dir {} is invalid!".format(img_root_dir))
self.img_root_dir = img_root_dir
self.img_names = [i for i in os.listdir(img_root_dir) if i.endswith('.png')]
self.max_captcha_digits = max_captcha_digits
self.blank = blank
def __len__(self):
return len(self.img_names)
def __getitem__(self, item):
img_name = self.img_names[item]
im = Image.open(os.path.join(self.img_root_dir, img_name))
r, g, b = im.split()
im = Image.merge("RGB", (b, g, r))
image = np.array(im)
label_str = os.path.splitext(img_name)[0]
label_str = label_str[label_str.find('-') + 1:]
label = [int(i) for i in label_str]
label.extend([int(self.blank)] * (self.max_captcha_digits - len(label)))
label = np.array(label)
return image, label
def create_dataset(dataset_path, repeat_num=1, batch_size=1):
"""
create train or evaluation dataset for warpctc
Args:
dataset_path(str): dataset path
repeat_num(int): dataset repetition num, default is 1
batch_size(int): batch size of generated dataset, default is 1
"""
rank_size = int(os.environ.get("RANK_SIZE")) if os.environ.get("RANK_SIZE") else 1
rank_id = int(os.environ.get("RANK_ID")) if os.environ.get("RANK_ID") else 0
dataset = _CaptchaDataset(dataset_path, cf.max_captcha_digits)
ds = de.GeneratorDataset(dataset, ["image", "label"], shuffle=True, num_shards=rank_size, shard_id=rank_id)
ds.set_dataset_size(m.ceil(len(dataset) / rank_size))
image_trans = [
vc.Rescale(1.0 / 255.0, 0.0),
vc.Normalize([0.9010, 0.9049, 0.9025], std=[0.1521, 0.1347, 0.1458]),
vc.Resize((m.ceil(cf.captcha_height / 16) * 16, cf.captcha_width)),
vc.HWC2CHW()
]
label_trans = [
c.TypeCast(mstype.int32)
]
ds = ds.map(input_columns=["image"], num_parallel_workers=8, operations=image_trans)
ds = ds.map(input_columns=["label"], num_parallel_workers=8, operations=label_trans)
ds = ds.batch(batch_size)
ds = ds.repeat(repeat_num)
return ds
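To make the label scheme concrete: _CaptchaDataset reads the digits from the file name and pads to max_captcha_digits with the blank value. A standalone sketch (not part of this commit; the file name is illustrative):

```python
img_name = "00042-317.png"                        # generated for the digit string "317"
label_str = img_name[:-len(".png")]
label_str = label_str[label_str.find('-') + 1:]   # "317"
label = [int(c) for c in label_str]               # [3, 1, 7]
label.extend([10] * (4 - len(label)))             # pad with blank=10
assert label == [3, 1, 7, 10]
```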

model_zoo/warpctc/src/loss.py (new executable file, 49 lines)
@@ -0,0 +1,49 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""CTC Loss."""
import numpy as np
from mindspore.nn.loss.loss import _Loss
from mindspore import Tensor, Parameter
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
class CTCLoss(_Loss):
"""
CTCLoss definition
Args:
max_sequence_length(int): maximum sequence length. For captcha images, the value is equal to the image width.
max_label_length(int): maximum label length for each input.
batch_size(int): batch size of input logits
"""
def __init__(self, max_sequence_length, max_label_length, batch_size):
super(CTCLoss, self).__init__()
self.sequence_length = Parameter(Tensor(np.array([max_sequence_length] * batch_size), mstype.int32),
name="sequence_length")
labels_indices = []
for i in range(batch_size):
for j in range(max_label_length):
labels_indices.append([i, j])
self.labels_indices = Parameter(Tensor(np.array(labels_indices), mstype.int64), name="labels_indices")
self.reshape = P.Reshape()
self.ctc_loss = P.CTCLoss(ctc_merge_repeated=True)
def construct(self, logit, label):
labels_values = self.reshape(label, (-1,))
loss, _ = self.ctc_loss(logit, self.labels_indices, labels_values, self.sequence_length)
return loss
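CTCLoss feeds P.CTCLoss a dense index set: labels_indices enumerates every (batch, position) pair so that the flattened labels_values lines up with it. A standalone sketch for a toy batch (not part of this commit):

```python
import numpy as np

batch_size, max_label_length = 2, 4
labels_indices = np.array([[i, j] for i in range(batch_size)
                           for j in range(max_label_length)], dtype=np.int64)
# one row per label slot, in the same order as the label tensor reshaped to (-1,)
assert labels_indices.tolist() == [[0, 0], [0, 1], [0, 2], [0, 3],
                                   [1, 0], [1, 1], [1, 2], [1, 3]]
```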

model_zoo/warpctc/src/lr_schedule.py (new file, 36 lines)
@@ -0,0 +1,36 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Learning rate generator."""
def get_lr(epoch_size, step_size, lr_init):
"""
generate learning rate for each step; the rate is decayed by a factor of 10 every 10 epochs
Args:
epoch_size(int): total epoch number
step_size(int): total step number in each epoch
lr_init(float): initial learning rate
Returns:
List, learning rate array
"""
lr = lr_init
lrs = []
for i in range(1, epoch_size + 1):
if i % 10 == 0:
lr *= 0.1
lrs.extend([lr for _ in range(step_size)])
return lrs
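Under the defaults in src/config.py (30 epochs, 98 steps per epoch, initial rate 0.01), get_lr holds 0.01 through epoch 9 and decays tenfold at epochs 10, 20 and 30. A quick sanity check (a sketch, not part of this commit):

```python
from src.lr_schedule import get_lr  # as imported by train.py

lrs = get_lr(epoch_size=30, step_size=98, lr_init=0.01)
assert len(lrs) == 30 * 98
assert lrs[0] == 0.01                   # epochs 1-9 keep the initial rate
assert abs(lrs[9 * 98] - 1e-3) < 1e-9   # first step of epoch 10, after one decay
assert abs(lrs[-1] - 1e-5) < 1e-9       # epoch 30, after three decays
```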

model_zoo/warpctc/src/metric.py (new executable file, 89 lines)
@@ -0,0 +1,89 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Metric for accuracy evaluation."""
from mindspore import nn
BLANK_LABEL = 10
class WarpCTCAccuracy(nn.Metric):
"""
Define accuracy metric for warpctc network.
"""
def __init__(self):
super(WarpCTCAccuracy, self).__init__()
self._correct_num = 0
self._total_num = 0
self._count = 0
def clear(self):
self._correct_num = 0
self._total_num = 0
def update(self, *inputs):
if len(inputs) != 2:
raise ValueError('WarpCTCAccuracy needs 2 inputs (y_pred, y), but got {}'.format(len(inputs)))
y_pred = self._convert_data(inputs[0])
y = self._convert_data(inputs[1])
self._count += 1
pred_lbls = self._get_prediction(y_pred)
for b_idx, target in enumerate(y):
if self._is_eq(pred_lbls[b_idx], target):
self._correct_num += 1
self._total_num += 1
def eval(self):
if self._total_num == 0:
raise RuntimeError('Accuracy cannot be calculated, because the number of samples is 0.')
return self._correct_num / self._total_num
@staticmethod
def _is_eq(pred_lbl, target):
"""
check whether predict label is equal to target label
"""
target = target.tolist()
pred_diff = len(target) - len(pred_lbl)
if pred_diff > 0:
# pad with BLANK_LABEL
pred_lbl.extend([BLANK_LABEL] * pred_diff)
return pred_lbl == target
@staticmethod
def _get_prediction(y_pred):
"""
parse predict result to labels
"""
seq_len, batch_size, _ = y_pred.shape
indices = y_pred.argmax(axis=2)
lens = [seq_len] * batch_size
pred_lbls = []
for i in range(batch_size):
idx = indices[:, i]
last_idx = BLANK_LABEL
pred_lbl = []
for j in range(lens[i]):
cur_idx = idx[j]
if cur_idx not in [last_idx, BLANK_LABEL]:
pred_lbl.append(cur_idx)
last_idx = cur_idx
pred_lbls.append(pred_lbl)
return pred_lbls
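_get_prediction is the usual greedy CTC decode: per-step argmax, collapse consecutive repeats, drop blanks. The same rule in isolation (a sketch, not part of this commit):

```python
def ctc_greedy_decode(step_argmax, blank=10):
    """Collapse consecutive repeats and drop blanks, as _get_prediction does."""
    decoded, last = [], blank
    for idx in step_argmax:
        if idx not in (last, blank):
            decoded.append(idx)
        last = idx
    return decoded

assert ctc_greedy_decode([3, 3, 10, 1, 1, 10, 10, 7]) == [3, 1, 7]
```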

model_zoo/warpctc/src/warpctc.py (new file, 90 lines)
@@ -0,0 +1,90 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Warpctc network definition."""
import numpy as np
import mindspore.nn as nn
from mindspore import Tensor, Parameter
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
from mindspore.ops import functional as F
class StackedRNN(nn.Cell):
"""
Define a stacked RNN network which contains two LSTM layers and one fully-connected layer.
Args:
input_size(int): input feature size of each time step. For captcha images, input_size is equal to three times
the image height (channels * height).
batch_size(int): batch size of input data, default is 64
hidden_size(int): the hidden size in LSTM layers, default is 512
"""
def __init__(self, input_size, batch_size=64, hidden_size=512):
super(StackedRNN, self).__init__()
self.batch_size = batch_size
self.input_size = input_size
self.num_classes = 11
self.reshape = P.Reshape()
self.cast = P.Cast()
k = (1 / hidden_size) ** 0.5
self.h1 = Tensor(np.zeros(shape=(batch_size, hidden_size)).astype(np.float16))
self.c1 = Tensor(np.zeros(shape=(batch_size, hidden_size)).astype(np.float16))
self.w1 = Parameter(np.random.uniform(-k, k, (4 * hidden_size, input_size + hidden_size, 1, 1))
.astype(np.float16), name="w1")
self.w2 = Parameter(np.random.uniform(-k, k, (4 * hidden_size, hidden_size + hidden_size, 1, 1))
.astype(np.float16), name="w2")
self.b1 = Parameter(np.random.uniform(-k, k, (4 * hidden_size, 1, 1, 1)).astype(np.float16), name="b1")
self.b2 = Parameter(np.random.uniform(-k, k, (4 * hidden_size, 1, 1, 1)).astype(np.float16), name="b2")
self.h2 = Tensor(np.zeros(shape=(batch_size, hidden_size)).astype(np.float16))
self.c2 = Tensor(np.zeros(shape=(batch_size, hidden_size)).astype(np.float16))
self.basic_lstm_cell = P.BasicLSTMCell(keep_prob=1.0, forget_bias=0.0, state_is_tuple=True, activation="tanh")
self.fc_weight = np.random.random((self.num_classes, hidden_size)).astype(np.float32)
self.fc_bias = np.random.random((self.num_classes)).astype(np.float32)
self.fc = nn.Dense(in_channels=hidden_size, out_channels=self.num_classes, weight_init=Tensor(self.fc_weight),
bias_init=Tensor(self.fc_bias))
self.fc.to_float(mstype.float32)
self.expand_dims = P.ExpandDims()
self.concat = P.Concat()
self.transpose = P.Transpose()
def construct(self, x):
x = self.cast(x, mstype.float16)
x = self.transpose(x, (3, 0, 2, 1))
x = self.reshape(x, (-1, self.batch_size, self.input_size))
h1 = self.h1
c1 = self.c1
h2 = self.h2
c2 = self.c2
c1, h1, _, _, _, _, _ = self.basic_lstm_cell(x[0, :, :], h1, c1, self.w1, self.b1)
c2, h2, _, _, _, _, _ = self.basic_lstm_cell(h1, h2, c2, self.w2, self.b2)
h2_after_fc = self.fc(h2)
output = self.expand_dims(h2_after_fc, 0)
for i in range(1, F.shape(x)[0]):
c1, h1, _, _, _, _, _ = self.basic_lstm_cell(x[i, :, :], h1, c1, self.w1, self.b1)
c2, h2, _, _, _, _, _ = self.basic_lstm_cell(h1, h2, c2, self.w2, self.b2)
h2_after_fc = self.fc(h2)
h2_after_fc = self.expand_dims(h2_after_fc, 0)
output = self.concat((output, h2_after_fc))
return output
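For orientation, the tensor shapes through construct under the defaults (batch_size=64, hidden_size=512, 64x160 captcha images, so input_size = ceil(64/64)*64*3 = 192); a comment-only sketch, not part of this commit:

```python
# x (NCHW batch):          (64, 3, 64, 160)
# transpose (3, 0, 2, 1):  (160, 64, 64, 3)   image width becomes the time axis
# reshape:                 (160, 64, 192)     (seq_len, batch_size, input_size)
# per time step:  h2 (64, 512) -> fc -> (64, 11)
# stacked output:          (160, 64, 11)      (seq_len, batch_size, num_classes), fed to CTCLoss
```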

model_zoo/warpctc/src/warpctc_for_train.py (new file, 114 lines)
@@ -0,0 +1,114 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Automatic differentiation with grad clip."""
from mindspore.parallel._utils import (_get_device_num, _get_mirror_mean,
_get_parallel_mode)
from mindspore.train.parallel_utils import ParallelMode
from mindspore.common import dtype as mstype
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.ops import operations as P
from mindspore.nn.cell import Cell
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
import mindspore.nn as nn
from mindspore.common.tensor import Tensor
import numpy as np
compute_norm = C.MultitypeFuncGraph("compute_norm")
@compute_norm.register("Tensor")
def _compute_norm(grad):
norm = nn.Norm()
norm = norm(F.cast(grad, mstype.float32))
ret = F.expand_dims(F.cast(norm, mstype.float32), 0)
return ret
grad_div = C.MultitypeFuncGraph("grad_div")
@grad_div.register("Tensor", "Tensor")
def _grad_div(val, grad):
div = P.Div()
mul = P.Mul()
grad = mul(grad, 10.0)
ret = div(grad, val)
return ret
class TrainOneStepCellWithGradClip(Cell):
"""
Network training package class.
Wraps the network with an optimizer. The resulting Cell can be trained with input data and label.
Backward graph with grad clip will be created in the construct function to do parameter updating.
Different parallel modes are available to run the training.
Args:
network (Cell): The training network.
optimizer (Cell): Optimizer for updating the weights.
sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.
Inputs:
- data (Tensor) - Tensor of shape :math:`(N, ...)`.
- label (Tensor) - Tensor of shape :math:`(N, ...)`.
Outputs:
Tensor, a scalar Tensor with shape :math:`()`.
"""
def __init__(self, network, optimizer, sens=1.0):
super(TrainOneStepCellWithGradClip, self).__init__(auto_prefix=False)
self.network = network
self.network.set_grad()
self.network.add_flags(defer_inline=True)
self.weights = optimizer.parameters
self.optimizer = optimizer
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
self.sens = sens
self.reducer_flag = False
self.grad_reducer = None
self.hyper_map = C.HyperMap()
self.greater = P.Greater()
self.select = P.Select()
self.norm = nn.Norm(keep_dims=True)
self.dtype = P.DType()
self.cast = P.Cast()
self.concat = P.Concat(axis=0)
self.ten = Tensor(np.array([10.0]).astype(np.float32))
parallel_mode = _get_parallel_mode()
if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
self.reducer_flag = True
if self.reducer_flag:
mean = _get_mirror_mean()
degree = _get_device_num()
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
def construct(self, data, label):
weights = self.weights
loss = self.network(data, label)
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
grads = self.grad(self.network, weights)(data, label, sens)
norm = self.hyper_map(F.partial(compute_norm), grads)
norm = self.concat(norm)
norm = self.norm(norm)
cond = self.greater(norm, self.cast(self.ten, self.dtype(norm)))
clip_val = self.select(cond, norm, self.cast(self.ten, self.dtype(norm)))
grads = self.hyper_map(F.partial(grad_div, clip_val), grads)
if self.reducer_flag:
# apply grad reducer on grads
grads = self.grad_reducer(grads)
return F.depend(loss, self.optimizer(grads))
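The clipping in construct is clip-by-global-norm with a fixed threshold of 10: every gradient is scaled by 10 / max(global_norm, 10). An equivalent NumPy sketch (not part of this commit):

```python
import numpy as np

def clip_by_global_norm(grads, clip_norm=10.0):
    """Scale each gradient by clip_norm / max(global_norm, clip_norm)."""
    global_norm = np.sqrt(sum(np.sum(np.square(g)) for g in grads))
    scale = clip_norm / max(global_norm, clip_norm)
    return [g * scale for g in grads]

grads = [np.full(4, 3.0), np.full(9, 4.0)]  # global norm = sqrt(36 + 144) ~ 13.42
clipped = clip_by_global_norm(grads)
new_norm = np.sqrt(sum(np.sum(np.square(g)) for g in clipped))
assert abs(new_norm - 10.0) < 1e-9          # clipped exactly to the threshold
```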

model_zoo/warpctc/train.py (new executable file, 84 lines)
@@ -0,0 +1,84 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Warpctc training"""
import os
import math as m
import random
import argparse
import numpy as np
import mindspore.nn as nn
from mindspore import context
from mindspore import dataset as de
from mindspore.train.model import Model, ParallelMode
from mindspore.nn.wrap import WithLossCell
from mindspore.train.callback import TimeMonitor, LossMonitor, CheckpointConfig, ModelCheckpoint
from mindspore.communication.management import init
from src.loss import CTCLoss
from src.config import config as cf
from src.dataset import create_dataset
from src.warpctc import StackedRNN
from src.warpctc_for_train import TrainOneStepCellWithGradClip
from src.lr_schedule import get_lr
random.seed(1)
np.random.seed(1)
de.config.set_seed(1)
parser = argparse.ArgumentParser(description="Warpctc training")
parser.add_argument("--run_distribute", type=bool, default=False, help="Run distribute, default is false.")
parser.add_argument('--device_num', type=int, default=1, help='Device num, default is 1.')
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path, default is None')
args_opt = parser.parse_args()
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE,
device_target="Ascend",
save_graphs=False,
device_id=device_id)
if __name__ == '__main__':
if args_opt.run_distribute:
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=args_opt.device_num,
parallel_mode=ParallelMode.DATA_PARALLEL,
mirror_mean=True)
init()
max_captcha_digits = cf.max_captcha_digits
input_size = m.ceil(cf.captcha_height / 64) * 64 * 3
# create dataset
dataset = create_dataset(dataset_path=args_opt.dataset_path, repeat_num=cf.epoch_size, batch_size=cf.batch_size)
step_size = dataset.get_dataset_size()
# define lr
lr_init = cf.learning_rate if not args_opt.run_distribute else cf.learning_rate * args_opt.device_num
lr = get_lr(cf.epoch_size, step_size, lr_init)
# define loss
loss = CTCLoss(max_sequence_length=cf.captcha_width, max_label_length=max_captcha_digits, batch_size=cf.batch_size)
# define net
net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
# define opt
opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=cf.momentum)
net = WithLossCell(net, loss)
net = TrainOneStepCellWithGradClip(net, opt).set_train()
# define model
model = Model(net)
# define callbacks
callbacks = [LossMonitor(), TimeMonitor(data_size=step_size)]
if cf.save_checkpoint:
config_ck = CheckpointConfig(save_checkpoint_steps=cf.save_checkpoint_steps,
keep_checkpoint_max=cf.keep_checkpoint_max)
ckpt_cb = ModelCheckpoint(prefix="waptctc", directory=cf.save_checkpoint_path, config=config_ck)
callbacks.append(ckpt_cb)
model.train(cf.epoch_size, dataset, callbacks=callbacks)