forked from mindspore-Ecosystem/mindspore
add warpctc to modelzoo
This commit is contained in:
parent
cf4d317d3e
commit
03c57a1e8b
|
@ -716,7 +716,7 @@ def get_bprop_basic_lstm_cell(self):
|
|||
def bprop(x, h, c, w, b, out, dout):
|
||||
_, _, it, jt, ft, ot, tanhct = out
|
||||
dct, dht, _, _, _, _, _ = dout
|
||||
dgate, dct_1 = basic_lstm_cell_cstate_grad(c, dht, dct, it, ft, jt, ot, tanhct)
|
||||
dgate, dct_1 = basic_lstm_cell_cstate_grad(c, dht, dct, it, jt, ft, ot, tanhct)
|
||||
dxt, dht = basic_lstm_cell_input_grad(dgate, w)
|
||||
dw, db = basic_lstm_cell_weight_grad(F.depend(x, dxt), h, dgate)
|
||||
return dxt, dht, dct_1, dw, db
|
||||
|
|
|
@ -29,8 +29,8 @@ basic_lstm_cell_c_state_grad_op_info = TBERegOp("BasicLSTMCellCStateGrad") \
|
|||
.input(1, "dht", False, "required", "all") \
|
||||
.input(2, "dct", False, "required", "all") \
|
||||
.input(3, "it", False, "required", "all") \
|
||||
.input(4, "ft", False, "required", "all") \
|
||||
.input(5, "jt", False, "required", "all") \
|
||||
.input(4, "jt", False, "required", "all") \
|
||||
.input(5, "ft", False, "required", "all") \
|
||||
.input(6, "ot", False, "required", "all") \
|
||||
.input(7, "tanhct", False, "required", "all") \
|
||||
.output(0, "dgate", False, "required", "all") \
|
||||
|
|
|
@ -0,0 +1,137 @@
|
|||
# Warpctc Example
|
||||
|
||||
## Description
|
||||
|
||||
These is an example of training Warpctc with self-generated captcha image dataset in MindSpore.
|
||||
|
||||
## Requirements
|
||||
|
||||
- Install [MindSpore](https://www.mindspore.cn/install/en).
|
||||
|
||||
- Generate captcha images.
|
||||
|
||||
> The [captcha](https://github.com/lepture/captcha) library can be used to generate captcha images. You can generate the train and test dataset by yourself or just run the script `scripts/run_process_data.sh`. By default, the shell script will generate 10000 test images and 50000 train images separately.
|
||||
> ```
|
||||
> $ cd scripts
|
||||
> $ sh run_process_data.sh
|
||||
>
|
||||
> # after execution, you will find the dataset like the follows:
|
||||
> .
|
||||
> └─warpctc
|
||||
> └─data
|
||||
> ├─ train # train dataset
|
||||
> └─ test # evaluate dataset
|
||||
> ...
|
||||
|
||||
|
||||
## Structure
|
||||
|
||||
```shell
|
||||
.
|
||||
└──warpct
|
||||
├── README.md
|
||||
├── script
|
||||
├── run_distribute_train.sh # launch distributed training(8 pcs)
|
||||
├── run_eval.sh # launch evaluation
|
||||
├── run_process_data.sh # launch dataset generation
|
||||
└── run_standalone_train.sh # launch standalone training(1 pcs)
|
||||
├── src
|
||||
├── config.py # parameter configuration
|
||||
├── dataset.py # data preprocessing
|
||||
├── loss.py # ctcloss definition
|
||||
├── lr_generator.py # generate learning rate for each step
|
||||
├── metric.py # accuracy metric for warpctc network
|
||||
├── warpctc.py # warpctc network definition
|
||||
└── warpctc_for_train.py # warp network with grad, loss and gradient clip
|
||||
├── eval.py # eval net
|
||||
├── process_data.py # dataset generation script
|
||||
└── train.py # train net
|
||||
```
|
||||
|
||||
|
||||
## Parameter configuration
|
||||
|
||||
Parameters for both training and evaluation can be set in config.py.
|
||||
|
||||
```
|
||||
"max_captcha_digits": 4, # max number of digits in each
|
||||
"captcha_width": 160, # width of captcha images
|
||||
"captcha_height": 64, # height of capthca images
|
||||
"batch_size": 64, # batch size of input tensor
|
||||
"epoch_size": 30, # only valid for taining, which is always 1 for inference
|
||||
"hidden_size": 512, # hidden size in LSTM layers
|
||||
"learning_rate": 0.01, # initial learning rate
|
||||
"momentum": 0.9 # momentum of SGD optimizer
|
||||
"save_checkpoint": True, # whether save checkpoint or not
|
||||
"save_checkpoint_steps": 98, # the step interval between two checkpoints. By default, the last checkpoint will be saved after the last step
|
||||
"keep_checkpoint_max": 30, # only keep the last keep_checkpoint_max checkpoint
|
||||
"save_checkpoint_path": "./", # path to save checkpoint
|
||||
```
|
||||
|
||||
## Running the example
|
||||
|
||||
### Train
|
||||
|
||||
#### Usage
|
||||
|
||||
```
|
||||
# distributed training
|
||||
Usage: sh run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [DATASET_PATH]
|
||||
|
||||
# standalone training
|
||||
Usage: sh run_standalone_train.sh [DATASET_PATH]
|
||||
```
|
||||
|
||||
|
||||
#### Launch
|
||||
|
||||
```
|
||||
# distribute training example
|
||||
sh run_distribute_train.sh rank_table.json ../data/train
|
||||
|
||||
# standalone training example
|
||||
sh run_standalone_train.sh ../data/train
|
||||
```
|
||||
|
||||
> About rank_table.json, you can refer to the [distributed training tutorial](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html).
|
||||
|
||||
#### Result
|
||||
|
||||
Training result will be stored in folder `scripts`, whose name begins with "train" or "train_parallel". Under this, you can find checkpoint file together with result like the followings in log.
|
||||
|
||||
```
|
||||
# distribute training result(8 pcs)
|
||||
Epoch: [ 1/ 30], step: [ 98/ 98], loss: [0.5853/0.5853], time: [376813.7944]
|
||||
Epoch: [ 2/ 30], step: [ 98/ 98], loss: [0.4007/0.4007], time: [75882.0951]
|
||||
Epoch: [ 3/ 30], step: [ 98/ 98], loss: [0.0921/0.0921], time: [75150.9385]
|
||||
Epoch: [ 4/ 30], step: [ 98/ 98], loss: [0.1472/0.1472], time: [75135.0193]
|
||||
Epoch: [ 5/ 30], step: [ 98/ 98], loss: [0.0186/0.0186], time: [75199.5809]
|
||||
...
|
||||
```
|
||||
|
||||
|
||||
### Evaluation
|
||||
|
||||
#### Usage
|
||||
|
||||
```
|
||||
# evaluation
|
||||
Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]
|
||||
```
|
||||
|
||||
#### Launch
|
||||
|
||||
```
|
||||
# evaluation example
|
||||
sh run_eval.sh ../data/test warpctc-30-98.ckpt
|
||||
```
|
||||
|
||||
> checkpoint can be produced in training process.
|
||||
|
||||
#### Result
|
||||
|
||||
Evaluation result will be stored in the example path, whose folder name is "eval". Under this, you can find result like the followings in log.
|
||||
|
||||
```
|
||||
result: {'WarpCTCAccuracy': 0.9901472929936306}
|
||||
```
|
|
@ -0,0 +1,65 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Warpctc evaluation"""
|
||||
import os
|
||||
import math as m
|
||||
import random
|
||||
import argparse
|
||||
import numpy as np
|
||||
from mindspore import context
|
||||
from mindspore import dataset as de
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
from src.loss import CTCLoss
|
||||
from src.config import config as cf
|
||||
from src.dataset import create_dataset
|
||||
from src.warpctc import StackedRNN
|
||||
from src.metric import WarpCTCAccuracy
|
||||
|
||||
random.seed(1)
|
||||
np.random.seed(1)
|
||||
de.config.set_seed(1)
|
||||
|
||||
parser = argparse.ArgumentParser(description="Warpctc training")
|
||||
parser.add_argument("--dataset_path", type=str, default=None, help="Dataset, default is None.")
|
||||
parser.add_argument("--checkpoint_path", type=str, default=None, help="checkpoint file path, default is None")
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(mode=context.GRAPH_MODE,
|
||||
device_target="Ascend",
|
||||
save_graphs=False,
|
||||
device_id=device_id)
|
||||
|
||||
if __name__ == '__main__':
|
||||
max_captcha_digits = cf.max_captcha_digits
|
||||
input_size = m.ceil(cf.captcha_height / 64) * 64 * 3
|
||||
# create dataset
|
||||
dataset = create_dataset(dataset_path=args_opt.dataset_path, repeat_num=1, batch_size=cf.batch_size)
|
||||
step_size = dataset.get_dataset_size()
|
||||
# define loss
|
||||
loss = CTCLoss(max_sequence_length=cf.captcha_width, max_label_length=max_captcha_digits, batch_size=cf.batch_size)
|
||||
# define net
|
||||
net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
|
||||
# load checkpoint
|
||||
param_dict = load_checkpoint(args_opt.checkpoint_path)
|
||||
load_param_into_net(net, param_dict)
|
||||
net.set_train(False)
|
||||
# define model
|
||||
model = Model(net, loss_fn=loss, metrics={'WarpCTCAccuracy': WarpCTCAccuracy()})
|
||||
# start evaluation
|
||||
res = model.eval(dataset)
|
||||
print("result:", res, flush=True)
|
|
@ -0,0 +1,71 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Generate train and test dataset"""
|
||||
import os
|
||||
import math as m
|
||||
import random
|
||||
from multiprocessing import Process
|
||||
from captcha.image import ImageCaptcha
|
||||
|
||||
|
||||
def _generate_captcha_per_process(path, total, start, end, img_width, img_height, max_digits):
|
||||
captcha = ImageCaptcha(width=img_width, height=img_height)
|
||||
filename_head = '{:0>' + str(len(str(total))) + '}-'
|
||||
for i in range(start, end):
|
||||
digits = ''
|
||||
digits_length = random.randint(1, max_digits)
|
||||
for _ in range(0, digits_length):
|
||||
integer = random.randint(0, 9)
|
||||
digits += str(integer)
|
||||
captcha.write(digits, os.path.join(path, filename_head.format(i) + digits + '.png'))
|
||||
|
||||
|
||||
def generate_captcha(name, img_num, img_width, img_height, max_digits, process_num=16):
|
||||
"""
|
||||
generate captcha images
|
||||
|
||||
Args:
|
||||
name(str): name of folder, under which captcha images are saved in
|
||||
img_num(int): number of generated captcha images
|
||||
img_width(int): width of generated captcha images
|
||||
img_height(int): height of generated captcha images
|
||||
max_digits(int): max number of digits in each captcha images. For each captcha images, number of digits is in
|
||||
range [1,max_digits]
|
||||
process_num(int): number of process to generate captcha images, default is 16
|
||||
"""
|
||||
cur_script_path = os.path.dirname(os.path.realpath(__file__))
|
||||
path = os.path.join(cur_script_path, "data", name)
|
||||
print("Generating dataset [{}] under {}...".format(name, path))
|
||||
if os.path.exists(path):
|
||||
os.system("rm -rf {}".format(path))
|
||||
os.system("mkdir -p {}".format(path))
|
||||
img_num_per_thread = m.ceil(img_num / process_num)
|
||||
|
||||
processes = []
|
||||
for i in range(process_num):
|
||||
start = i * img_num_per_thread
|
||||
end = start + img_num_per_thread if i != (process_num - 1) else img_num
|
||||
p = Process(target=_generate_captcha_per_process,
|
||||
args=(path, img_num, start, end, img_width, img_height, max_digits))
|
||||
p.start()
|
||||
processes.append(p)
|
||||
for p in processes:
|
||||
p.join()
|
||||
print("Generating dataset [{}] finished, total number is {}!".format(name, img_num))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
generate_captcha("test", img_num=10000, img_width=160, img_height=64, max_digits=4)
|
||||
generate_captcha("train", img_num=50000, img_width=160, img_height=64, max_digits=4)
|
|
@ -0,0 +1,62 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 2 ]; then
|
||||
echo "Usage: sh run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [DATASET_PATH]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path() {
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
PATH1=$(get_real_path $1)
|
||||
PATH2=$(get_real_path $2)
|
||||
|
||||
if [ ! -f $PATH1 ]; then
|
||||
echo "error: MINDSPORE_HCCL_CONFIG_PATH=$PATH1 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -d $PATH2 ]; then
|
||||
echo "error: DATASET_PATH=$PATH2 is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=8
|
||||
export RANK_SIZE=8
|
||||
export MINDSPORE_HCCL_CONFIG_PATH=$PATH1
|
||||
export RANK_TABLE_FILE=$PATH1
|
||||
|
||||
for ((i = 0; i < ${DEVICE_NUM}; i++)); do
|
||||
export DEVICE_ID=$i
|
||||
export RANK_ID=$i
|
||||
rm -rf ./train_parallel$i
|
||||
mkdir ./train_parallel$i
|
||||
cp ../*.py ./train_parallel$i
|
||||
cp *.sh ./train_parallel$i
|
||||
cp -r ../src ./train_parallel$i
|
||||
cd ./train_parallel$i || exit
|
||||
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||
env >env.log
|
||||
python train.py --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 &>log &
|
||||
cd ..
|
||||
done
|
|
@ -0,0 +1,60 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 2 ]; then
|
||||
echo "Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path() {
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
PATH1=$(get_real_path $1)
|
||||
PATH2=$(get_real_path $2)
|
||||
|
||||
if [ ! -d $PATH1 ]; then
|
||||
echo "error: DATASET_PATH=$PATH1 is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f $PATH2 ]; then
|
||||
echo "error: CHECKPOINT_PATH=$PATH2 is not a file"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=1
|
||||
export DEVICE_ID=0
|
||||
export RANK_SIZE=$DEVICE_NUM
|
||||
export RANK_ID=0
|
||||
|
||||
if [ -d "eval" ]; then
|
||||
rm -rf ./eval
|
||||
fi
|
||||
mkdir ./eval
|
||||
cp ../*.py ./eval
|
||||
cp *.sh ./eval
|
||||
cp -r ../src ./eval
|
||||
cd ./eval || exit
|
||||
env >env.log
|
||||
echo "start evaluation for device $DEVICE_ID"
|
||||
python eval.py --dataset_path=$PATH1 --checkpoint_path=$PATH2 &>log &
|
||||
cd ..
|
|
@ -0,0 +1,20 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
CUR_PATH=$(dirname $PWD/$0)
|
||||
cd $CUR_PATH/../ &&
|
||||
python process_data.py &&
|
||||
cd - || exit
|
|
@ -0,0 +1,54 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 1 ]; then
|
||||
echo "Usage: sh run_standalone_train.sh [DATASET_PATH]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path() {
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
PATH1=$(get_real_path $1)
|
||||
|
||||
if [ ! -d $PATH1 ]; then
|
||||
echo "error: DATASET_PATH=$PATH1 is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=1
|
||||
export DEVICE_ID=0
|
||||
export RANK_ID=0
|
||||
export RANK_SIZE=1
|
||||
|
||||
if [ -d "train" ]; then
|
||||
rm -rf ./train
|
||||
fi
|
||||
mkdir ./train
|
||||
cp ../*.py ./train
|
||||
cp *.sh ./train
|
||||
cp -r ../src ./train
|
||||
cd ./train || exit
|
||||
echo "start training for device $DEVICE_ID"
|
||||
env >env.log
|
||||
python train.py --dataset=$PATH1 &>log &
|
||||
cd ..
|
|
@ -0,0 +1,31 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Network parameters."""
|
||||
from easydict import EasyDict
|
||||
|
||||
config = EasyDict({
|
||||
"max_captcha_digits": 4,
|
||||
"captcha_width": 160,
|
||||
"captcha_height": 64,
|
||||
"batch_size": 64,
|
||||
"epoch_size": 30,
|
||||
"hidden_size": 512,
|
||||
"learning_rate": 0.01,
|
||||
"momentum": 0.9,
|
||||
"save_checkpoint": True,
|
||||
"save_checkpoint_steps": 98,
|
||||
"keep_checkpoint_max": 30,
|
||||
"save_checkpoint_path": "./",
|
||||
})
|
|
@ -0,0 +1,92 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Dataset preprocessing."""
|
||||
import os
|
||||
import math as m
|
||||
import numpy as np
|
||||
import mindspore.common.dtype as mstype
|
||||
import mindspore.dataset.engine as de
|
||||
import mindspore.dataset.transforms.c_transforms as c
|
||||
import mindspore.dataset.transforms.vision.c_transforms as vc
|
||||
from PIL import Image
|
||||
from src.config import config as cf
|
||||
|
||||
|
||||
class _CaptchaDataset():
|
||||
"""
|
||||
create train or evaluation dataset for warpctc
|
||||
|
||||
Args:
|
||||
img_root_dir(str): root path of images
|
||||
max_captcha_digits(int): max number of digits in images.
|
||||
blank(int): value reserved for blank label, default is 10. When parsing label from image file names, if label
|
||||
length is less than max_captcha_digits, the remaining labels are padding with blank.
|
||||
"""
|
||||
|
||||
def __init__(self, img_root_dir, max_captcha_digits, blank=10):
|
||||
if not os.path.exists(img_root_dir):
|
||||
raise RuntimeError("the input image dir {} is invalid!".format(img_root_dir))
|
||||
self.img_root_dir = img_root_dir
|
||||
self.img_names = [i for i in os.listdir(img_root_dir) if i.endswith('.png')]
|
||||
self.max_captcha_digits = max_captcha_digits
|
||||
self.blank = blank
|
||||
|
||||
def __len__(self):
|
||||
return len(self.img_names)
|
||||
|
||||
def __getitem__(self, item):
|
||||
img_name = self.img_names[item]
|
||||
im = Image.open(os.path.join(self.img_root_dir, img_name))
|
||||
r, g, b = im.split()
|
||||
im = Image.merge("RGB", (b, g, r))
|
||||
image = np.array(im)
|
||||
label_str = os.path.splitext(img_name)[0]
|
||||
label_str = label_str[label_str.find('-') + 1:]
|
||||
label = [int(i) for i in label_str]
|
||||
label.extend([int(self.blank)] * (self.max_captcha_digits - len(label)))
|
||||
label = np.array(label)
|
||||
return image, label
|
||||
|
||||
|
||||
def create_dataset(dataset_path, repeat_num=1, batch_size=1):
|
||||
"""
|
||||
create train or evaluation dataset for warpctc
|
||||
|
||||
Args:
|
||||
dataset_path(int): dataset path
|
||||
repeat_num(int): dataset repetition num, default is 1
|
||||
batch_size(int): batch size of generated dataset, default is 1
|
||||
"""
|
||||
rank_size = int(os.environ.get("RANK_SIZE")) if os.environ.get("RANK_SIZE") else 1
|
||||
rank_id = int(os.environ.get("RANK_ID")) if os.environ.get("RANK_ID") else 0
|
||||
|
||||
dataset = _CaptchaDataset(dataset_path, cf.max_captcha_digits)
|
||||
ds = de.GeneratorDataset(dataset, ["image", "label"], shuffle=True, num_shards=rank_size, shard_id=rank_id)
|
||||
ds.set_dataset_size(m.ceil(len(dataset) / rank_size))
|
||||
image_trans = [
|
||||
vc.Rescale(1.0 / 255.0, 0.0),
|
||||
vc.Normalize([0.9010, 0.9049, 0.9025], std=[0.1521, 0.1347, 0.1458]),
|
||||
vc.Resize((m.ceil(cf.captcha_height / 16) * 16, cf.captcha_width)),
|
||||
vc.HWC2CHW()
|
||||
]
|
||||
label_trans = [
|
||||
c.TypeCast(mstype.int32)
|
||||
]
|
||||
ds = ds.map(input_columns=["image"], num_parallel_workers=8, operations=image_trans)
|
||||
ds = ds.map(input_columns=["label"], num_parallel_workers=8, operations=label_trans)
|
||||
|
||||
ds = ds.batch(batch_size)
|
||||
ds = ds.repeat(repeat_num)
|
||||
return ds
|
|
@ -0,0 +1,49 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""CTC Loss."""
|
||||
import numpy as np
|
||||
from mindspore.nn.loss.loss import _Loss
|
||||
from mindspore import Tensor, Parameter
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
|
||||
|
||||
class CTCLoss(_Loss):
|
||||
"""
|
||||
CTCLoss definition
|
||||
|
||||
Args:
|
||||
max_sequence_length(int): max number of sequence length. For captcha images, the value is equal to image
|
||||
width
|
||||
max_label_length(int): max number of label length for each input.
|
||||
batch_size(int): batch size of input logits
|
||||
"""
|
||||
|
||||
def __init__(self, max_sequence_length, max_label_length, batch_size):
|
||||
super(CTCLoss, self).__init__()
|
||||
self.sequence_length = Parameter(Tensor(np.array([max_sequence_length] * batch_size), mstype.int32),
|
||||
name="sequence_length")
|
||||
labels_indices = []
|
||||
for i in range(batch_size):
|
||||
for j in range(max_label_length):
|
||||
labels_indices.append([i, j])
|
||||
self.labels_indices = Parameter(Tensor(np.array(labels_indices), mstype.int64), name="labels_indices")
|
||||
self.reshape = P.Reshape()
|
||||
self.ctc_loss = P.CTCLoss(ctc_merge_repeated=True)
|
||||
|
||||
def construct(self, logit, label):
|
||||
labels_values = self.reshape(label, (-1,))
|
||||
loss, _ = self.ctc_loss(logit, self.labels_indices, labels_values, self.sequence_length)
|
||||
return loss
|
|
@ -0,0 +1,36 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Learning rate generator."""
|
||||
|
||||
|
||||
def get_lr(epoch_size, step_size, lr_init):
|
||||
"""
|
||||
generate learning rate for each step, which decays in every 10 epoch
|
||||
|
||||
Args:
|
||||
epoch_size(int): total epoch number
|
||||
step_size(int): total step number in each step
|
||||
lr_init(int): initial learning rate
|
||||
|
||||
Returns:
|
||||
List, learning rate array
|
||||
"""
|
||||
lr = lr_init
|
||||
lrs = []
|
||||
for i in range(1, epoch_size + 1):
|
||||
if i % 10 == 0:
|
||||
lr *= 0.1
|
||||
lrs.extend([lr for _ in range(step_size)])
|
||||
return lrs
|
|
@ -0,0 +1,89 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Metric for accuracy evaluation."""
|
||||
from mindspore import nn
|
||||
|
||||
BLANK_LABLE = 10
|
||||
|
||||
|
||||
class WarpCTCAccuracy(nn.Metric):
|
||||
"""
|
||||
Define accuracy metric for warpctc network.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super(WarpCTCAccuracy).__init__()
|
||||
self._correct_num = 0
|
||||
self._total_num = 0
|
||||
self._count = 0
|
||||
|
||||
def clear(self):
|
||||
self._correct_num = 0
|
||||
self._total_num = 0
|
||||
|
||||
def update(self, *inputs):
|
||||
if len(inputs) != 2:
|
||||
raise ValueError('WarpCTCAccuracy need 2 inputs (y_pred, y), but got {}'.format(len(inputs)))
|
||||
|
||||
y_pred = self._convert_data(inputs[0])
|
||||
y = self._convert_data(inputs[1])
|
||||
|
||||
self._count += 1
|
||||
|
||||
pred_lbls = self._get_prediction(y_pred)
|
||||
|
||||
for b_idx, target in enumerate(y):
|
||||
if self._is_eq(pred_lbls[b_idx], target):
|
||||
self._correct_num += 1
|
||||
self._total_num += 1
|
||||
|
||||
def eval(self):
|
||||
if self._total_num == 0:
|
||||
raise RuntimeError('Accuary can not be calculated, because the number of samples is 0.')
|
||||
return self._correct_num / self._total_num
|
||||
|
||||
@staticmethod
|
||||
def _is_eq(pred_lbl, target):
|
||||
"""
|
||||
check whether predict label is equal to target label
|
||||
"""
|
||||
target = target.tolist()
|
||||
pred_diff = len(target) - len(pred_lbl)
|
||||
if pred_diff > 0:
|
||||
# padding by BLANK_LABLE
|
||||
pred_lbl.extend([BLANK_LABLE] * pred_diff)
|
||||
return pred_lbl == target
|
||||
|
||||
@staticmethod
|
||||
def _get_prediction(y_pred):
|
||||
"""
|
||||
parse predict result to labels
|
||||
"""
|
||||
seq_len, batch_size, _ = y_pred.shape
|
||||
indices = y_pred.argmax(axis=2)
|
||||
|
||||
lens = [seq_len] * batch_size
|
||||
pred_lbls = []
|
||||
for i in range(batch_size):
|
||||
idx = indices[:, i]
|
||||
last_idx = BLANK_LABLE
|
||||
pred_lbl = []
|
||||
for j in range(lens[i]):
|
||||
cur_idx = idx[j]
|
||||
if cur_idx not in [last_idx, BLANK_LABLE]:
|
||||
pred_lbl.append(cur_idx)
|
||||
last_idx = cur_idx
|
||||
pred_lbls.append(pred_lbl)
|
||||
return pred_lbls
|
|
@ -0,0 +1,90 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Warpctc network definition."""
|
||||
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
from mindspore import Tensor, Parameter
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
|
||||
|
||||
class StackedRNN(nn.Cell):
|
||||
"""
|
||||
Define a stacked RNN network which contains two LSTM layers and one full-connect layer.
|
||||
|
||||
Args:
|
||||
input_size(int): Size of time sequence. Usually, the input_size is equal to three times of image height for
|
||||
captcha images.
|
||||
batch_size(int): batch size of input data, default is 64
|
||||
hidden_size(int): the hidden size in LSTM layers, default is 512
|
||||
"""
|
||||
def __init__(self, input_size, batch_size=64, hidden_size=512):
|
||||
super(StackedRNN, self).__init__()
|
||||
self.batch_size = batch_size
|
||||
self.input_size = input_size
|
||||
self.num_classes = 11
|
||||
self.reshape = P.Reshape()
|
||||
self.cast = P.Cast()
|
||||
k = (1 / hidden_size) ** 0.5
|
||||
self.h1 = Tensor(np.zeros(shape=(batch_size, hidden_size)).astype(np.float16))
|
||||
self.c1 = Tensor(np.zeros(shape=(batch_size, hidden_size)).astype(np.float16))
|
||||
self.w1 = Parameter(np.random.uniform(-k, k, (4 * hidden_size, input_size + hidden_size, 1, 1))
|
||||
.astype(np.float16), name="w1")
|
||||
self.w2 = Parameter(np.random.uniform(-k, k, (4 * hidden_size, hidden_size + hidden_size, 1, 1))
|
||||
.astype(np.float16), name="w2")
|
||||
self.b1 = Parameter(np.random.uniform(-k, k, (4 * hidden_size, 1, 1, 1)).astype(np.float16), name="b1")
|
||||
self.b2 = Parameter(np.random.uniform(-k, k, (4 * hidden_size, 1, 1, 1)).astype(np.float16), name="b2")
|
||||
|
||||
self.h2 = Tensor(np.zeros(shape=(batch_size, hidden_size)).astype(np.float16))
|
||||
self.c2 = Tensor(np.zeros(shape=(batch_size, hidden_size)).astype(np.float16))
|
||||
|
||||
self.basic_lstm_cell = P.BasicLSTMCell(keep_prob=1.0, forget_bias=0.0, state_is_tuple=True, activation="tanh")
|
||||
|
||||
self.fc_weight = np.random.random((self.num_classes, hidden_size)).astype(np.float32)
|
||||
self.fc_bias = np.random.random((self.num_classes)).astype(np.float32)
|
||||
|
||||
self.fc = nn.Dense(in_channels=hidden_size, out_channels=self.num_classes, weight_init=Tensor(self.fc_weight),
|
||||
bias_init=Tensor(self.fc_bias))
|
||||
|
||||
self.fc.to_float(mstype.float32)
|
||||
self.expand_dims = P.ExpandDims()
|
||||
self.concat = P.Concat()
|
||||
self.transpose = P.Transpose()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.cast(x, mstype.float16)
|
||||
x = self.transpose(x, (3, 0, 2, 1))
|
||||
x = self.reshape(x, (-1, self.batch_size, self.input_size))
|
||||
h1 = self.h1
|
||||
c1 = self.c1
|
||||
h2 = self.h2
|
||||
c2 = self.c2
|
||||
|
||||
c1, h1, _, _, _, _, _ = self.basic_lstm_cell(x[0, :, :], h1, c1, self.w1, self.b1)
|
||||
c2, h2, _, _, _, _, _ = self.basic_lstm_cell(h1, h2, c2, self.w2, self.b2)
|
||||
|
||||
h2_after_fc = self.fc(h2)
|
||||
output = self.expand_dims(h2_after_fc, 0)
|
||||
for i in range(1, F.shape(x)[0]):
|
||||
c1, h1, _, _, _, _, _ = self.basic_lstm_cell(x[i, :, :], h1, c1, self.w1, self.b1)
|
||||
c2, h2, _, _, _, _, _ = self.basic_lstm_cell(h1, h2, c2, self.w2, self.b2)
|
||||
|
||||
h2_after_fc = self.fc(h2)
|
||||
h2_after_fc = self.expand_dims(h2_after_fc, 0)
|
||||
output = self.concat((output, h2_after_fc))
|
||||
|
||||
return output
|
|
@ -0,0 +1,114 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Automatic differentiation with grad clip."""
|
||||
from mindspore.parallel._utils import (_get_device_num, _get_mirror_mean,
|
||||
_get_parallel_mode)
|
||||
from mindspore.train.parallel_utils import ParallelMode
|
||||
from mindspore.common import dtype as mstype
|
||||
from mindspore.ops import composite as C
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.nn.cell import Cell
|
||||
from mindspore.nn.wrap.grad_reducer import DistributedGradReducer
|
||||
import mindspore.nn as nn
|
||||
from mindspore.common.tensor import Tensor
|
||||
import numpy as np
|
||||
|
||||
compute_norm = C.MultitypeFuncGraph("compute_norm")
|
||||
|
||||
|
||||
@compute_norm.register("Tensor")
|
||||
def _compute_norm(grad):
|
||||
norm = nn.Norm()
|
||||
norm = norm(F.cast(grad, mstype.float32))
|
||||
ret = F.expand_dims(F.cast(norm, mstype.float32), 0)
|
||||
return ret
|
||||
|
||||
|
||||
grad_div = C.MultitypeFuncGraph("grad_div")
|
||||
|
||||
|
||||
@grad_div.register("Tensor", "Tensor")
|
||||
def _grad_div(val, grad):
|
||||
div = P.Div()
|
||||
mul = P.Mul()
|
||||
grad = mul(grad, 10.0)
|
||||
ret = div(grad, val)
|
||||
return ret
|
||||
|
||||
|
||||
class TrainOneStepCellWithGradClip(Cell):
|
||||
"""
|
||||
Network training package class.
|
||||
|
||||
Wraps the network with an optimizer. The resulting Cell be trained with input data and label.
|
||||
Backward graph with grad clip will be created in the construct function to do parameter updating.
|
||||
Different parallel modes are available to run the training.
|
||||
|
||||
Args:
|
||||
network (Cell): The training network.
|
||||
optimizer (Cell): Optimizer for updating the weights.
|
||||
sens (Number): The scaling number to be filled as the input of backpropagation. Default value is 1.0.
|
||||
|
||||
Inputs:
|
||||
- data (Tensor) - Tensor of shape :(N, ...).
|
||||
- label (Tensor) - Tensor of shape :(N, ...).
|
||||
|
||||
Outputs:
|
||||
Tensor, a scalar Tensor with shape :math:`()`.
|
||||
"""
|
||||
|
||||
def __init__(self, network, optimizer, sens=1.0):
|
||||
super(TrainOneStepCellWithGradClip, self).__init__(auto_prefix=False)
|
||||
self.network = network
|
||||
self.network.set_grad()
|
||||
self.network.add_flags(defer_inline=True)
|
||||
self.weights = optimizer.parameters
|
||||
self.optimizer = optimizer
|
||||
self.grad = C.GradOperation('grad', get_by_list=True, sens_param=True)
|
||||
self.sens = sens
|
||||
self.reducer_flag = False
|
||||
self.grad_reducer = None
|
||||
self.hyper_map = C.HyperMap()
|
||||
self.greater = P.Greater()
|
||||
self.select = P.Select()
|
||||
self.norm = nn.Norm(keep_dims=True)
|
||||
self.dtype = P.DType()
|
||||
self.cast = P.Cast()
|
||||
self.concat = P.Concat(axis=0)
|
||||
self.ten = Tensor(np.array([10.0]).astype(np.float32))
|
||||
parallel_mode = _get_parallel_mode()
|
||||
if parallel_mode in (ParallelMode.DATA_PARALLEL, ParallelMode.HYBRID_PARALLEL):
|
||||
self.reducer_flag = True
|
||||
if self.reducer_flag:
|
||||
mean = _get_mirror_mean()
|
||||
degree = _get_device_num()
|
||||
self.grad_reducer = DistributedGradReducer(optimizer.parameters, mean, degree)
|
||||
|
||||
def construct(self, data, label):
|
||||
weights = self.weights
|
||||
loss = self.network(data, label)
|
||||
sens = P.Fill()(P.DType()(loss), P.Shape()(loss), self.sens)
|
||||
grads = self.grad(self.network, weights)(data, label, sens)
|
||||
norm = self.hyper_map(F.partial(compute_norm), grads)
|
||||
norm = self.concat(norm)
|
||||
norm = self.norm(norm)
|
||||
cond = self.greater(norm, self.cast(self.ten, self.dtype(norm)))
|
||||
clip_val = self.select(cond, norm, self.cast(self.ten, self.dtype(norm)))
|
||||
grads = self.hyper_map(F.partial(grad_div, clip_val), grads)
|
||||
if self.reducer_flag:
|
||||
# apply grad reducer on grads
|
||||
grads = self.grad_reducer(grads)
|
||||
return F.depend(loss, self.optimizer(grads))
|
|
@ -0,0 +1,84 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""Warpctc training"""
|
||||
import os
|
||||
import math as m
|
||||
import random
|
||||
import argparse
|
||||
import numpy as np
|
||||
import mindspore.nn as nn
|
||||
from mindspore import context
|
||||
from mindspore import dataset as de
|
||||
from mindspore.train.model import Model, ParallelMode
|
||||
from mindspore.nn.wrap import WithLossCell
|
||||
from mindspore.train.callback import TimeMonitor, LossMonitor, CheckpointConfig, ModelCheckpoint
|
||||
from mindspore.communication.management import init
|
||||
|
||||
from src.loss import CTCLoss
|
||||
from src.config import config as cf
|
||||
from src.dataset import create_dataset
|
||||
from src.warpctc import StackedRNN
|
||||
from src.warpctc_for_train import TrainOneStepCellWithGradClip
|
||||
from src.lr_schedule import get_lr
|
||||
|
||||
random.seed(1)
|
||||
np.random.seed(1)
|
||||
de.config.set_seed(1)
|
||||
|
||||
parser = argparse.ArgumentParser(description="Warpctc training")
|
||||
parser.add_argument("--run_distribute", type=bool, default=False, help="Run distribute, default is false.")
|
||||
parser.add_argument('--device_num', type=int, default=1, help='Device num, default is 1.')
|
||||
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path, default is None')
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(mode=context.GRAPH_MODE,
|
||||
device_target="Ascend",
|
||||
save_graphs=False,
|
||||
device_id=device_id)
|
||||
|
||||
if __name__ == '__main__':
|
||||
if args_opt.run_distribute:
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(device_num=args_opt.device_num,
|
||||
parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
mirror_mean=True)
|
||||
init()
|
||||
max_captcha_digits = cf.max_captcha_digits
|
||||
input_size = m.ceil(cf.captcha_height / 64) * 64 * 3
|
||||
# create dataset
|
||||
dataset = create_dataset(dataset_path=args_opt.dataset_path, repeat_num=cf.epoch_size, batch_size=cf.batch_size)
|
||||
step_size = dataset.get_dataset_size()
|
||||
# define lr
|
||||
lr_init = cf.learning_rate if not args_opt.run_distribute else cf.learning_rate * args_opt.device_num
|
||||
lr = get_lr(cf.epoch_size, step_size, lr_init)
|
||||
# define loss
|
||||
loss = CTCLoss(max_sequence_length=cf.captcha_width, max_label_length=max_captcha_digits, batch_size=cf.batch_size)
|
||||
# define net
|
||||
net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
|
||||
# define opt
|
||||
opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=cf.momentum)
|
||||
net = WithLossCell(net, loss)
|
||||
net = TrainOneStepCellWithGradClip(net, opt).set_train()
|
||||
# define model
|
||||
model = Model(net)
|
||||
# define callbacks
|
||||
callbacks = [LossMonitor(), TimeMonitor(data_size=step_size)]
|
||||
if cf.save_checkpoint:
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=cf.save_checkpoint_steps,
|
||||
keep_checkpoint_max=cf.keep_checkpoint_max)
|
||||
ckpt_cb = ModelCheckpoint(prefix="waptctc", directory=cf.save_checkpoint_path, config=config_ck)
|
||||
callbacks.append(ckpt_cb)
|
||||
model.train(cf.epoch_size, dataset, callbacks=callbacks)
|
Loading…
Reference in New Issue