!3421 Add WarpCTC GPU script

Merge pull request !3421 from yangyongjie/master
This commit is contained in:
mindspore-ci-bot 2020-07-25 11:37:35 +08:00 committed by Gitee
commit 669a8969c7
12 changed files with 316 additions and 103 deletions

View File

@ -31,7 +31,8 @@ These is an example of training Warpctc with self-generated captcha image datase
└──warpctc
├── README.md
├── script
├── run_distribute_train.sh # launch distributed training(8 pcs)
├── run_distribute_train.sh # launch distributed training in Ascend(8 pcs)
├── run_distribute_train_for_gpu.sh # launch distributed training in GPU
├── run_eval.sh # launch evaluation
├── run_process_data.sh # launch dataset generation
└── run_standalone_train.sh # launch standalone training(1 pcs)
@ -75,22 +76,31 @@ Parameters for both training and evaluation can be set in config.py.
#### Usage
```
# distributed training
# distributed training in Ascend
Usage: sh run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [DATASET_PATH]
# distributed training in GPU
Usage: sh run_distribute_train_for_gpu.sh [RANK_SIZE] [DATASET_PATH]
# standalone training
Usage: sh run_standalone_train.sh [DATASET_PATH]
Usage: sh run_standalone_train.sh [DATASET_PATH] [PLATFORM]
```
#### Launch
```
# distribute training example
# distribute training example in Ascend
sh run_distribute_train.sh rank_table.json ../data/train
# standalone training example
sh run_standalone_train.sh ../data/train
# distribute training example in GPU
sh run_distribute_train_for_gpu.sh 8 ../data/train
# standalone training example in Ascend
sh run_standalone_train.sh ../data/train Ascend
# standalone training example in GPU
sh run_standalone_train.sh ../data/train GPU
```
> About rank_table.json, you can refer to the [distributed training tutorial](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html).
@ -116,14 +126,17 @@ Epoch: [ 5/ 30], step: [ 98/ 98], loss: [0.0186/0.0186], time: [75199.5809]
```
# evaluation
Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]
Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH] [PLATFORM]
```
#### Launch
```
# evaluation example
sh run_eval.sh ../data/test warpctc-30-98.ckpt
# evaluation example in Ascend
sh run_eval.sh ../data/test warpctc-30-98.ckpt Ascend
# evaluation example in GPU
sh run_eval.sh ../data/test warpctc-30-98.ckpt GPU
```
> checkpoint can be produced in training process.

View File

@ -23,10 +23,10 @@ from mindspore import dataset as de
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.loss import CTCLoss
from src.loss import CTCLoss, CTCLossV2
from src.config import config as cf
from src.dataset import create_dataset
from src.warpctc import StackedRNN
from src.warpctc import StackedRNN, StackedRNNForGPU
from src.metric import WarpCTCAccuracy
random.seed(1)
@ -36,30 +36,38 @@ de.config.set_seed(1)
parser = argparse.ArgumentParser(description="Warpctc training")
parser.add_argument("--dataset_path", type=str, default=None, help="Dataset, default is None.")
parser.add_argument("--checkpoint_path", type=str, default=None, help="checkpoint file path, default is None")
parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend', 'GPU'],
help='Running platform, choose from Ascend, GPU, and default is Ascend.')
args_opt = parser.parse_args()
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE,
device_target="Ascend",
save_graphs=False,
device_id=device_id)
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, save_graphs=False)
if args_opt.platform == 'Ascend':
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(device_id=device_id)
if __name__ == '__main__':
max_captcha_digits = cf.max_captcha_digits
input_size = m.ceil(cf.captcha_height / 64) * 64 * 3
# create dataset
dataset = create_dataset(dataset_path=args_opt.dataset_path, repeat_num=1, batch_size=cf.batch_size)
dataset = create_dataset(dataset_path=args_opt.dataset_path,
batch_size=cf.batch_size,
device_target=args_opt.platform)
step_size = dataset.get_dataset_size()
# define loss
loss = CTCLoss(max_sequence_length=cf.captcha_width, max_label_length=max_captcha_digits, batch_size=cf.batch_size)
# define net
net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
if args_opt.platform == 'Ascend':
loss = CTCLoss(max_sequence_length=cf.captcha_width,
max_label_length=max_captcha_digits,
batch_size=cf.batch_size)
net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
else:
loss = CTCLossV2(max_sequence_length=cf.captcha_width, batch_size=cf.batch_size)
net = StackedRNNForGPU(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
# load checkpoint
param_dict = load_checkpoint(args_opt.checkpoint_path)
load_param_into_net(net, param_dict)
net.set_train(False)
# define model
model = Model(net, loss_fn=loss, metrics={'WarpCTCAccuracy': WarpCTCAccuracy()})
model = Model(net, loss_fn=loss, metrics={'WarpCTCAccuracy': WarpCTCAccuracy(args_opt.platform)})
# start evaluation
res = model.eval(dataset)
res = model.eval(dataset, dataset_sink_mode=args_opt.platform == 'Ascend')
print("result:", res, flush=True)

View File

@ -57,6 +57,6 @@ for ((i = 0; i < ${DEVICE_NUM}; i++)); do
cd ./train_parallel$i || exit
echo "start training for rank $RANK_ID, device $DEVICE_ID"
env >env.log
python train.py --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 &>log &
python train.py --platform=Ascend --dataset_path=$PATH2 --run_distribute > log.txt 2>&1 &
cd ..
done

View File

@ -0,0 +1,52 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Launch distributed WarpCTC training on GPU through mpirun.
# Arguments:
#   $1  RANK_SIZE     number of processes handed to `mpirun -n`
#   $2  DATASET_PATH  training dataset directory (absolute, or relative to $PWD)
if [ $# != 2 ]; then
    echo "Usage: sh run_distribute_train_for_gpu.sh [RANK_SIZE] [DATASET_PATH]"
    exit 1
fi

# Print $1 unchanged when it is already absolute, otherwise resolve it against $PWD.
get_real_path() {
    if [ "${1:0:1}" == "/" ]; then
        echo "$1"
    else
        realpath -m "$PWD/$1"
    fi
}

RANK_SIZE=$1
DATASET_PATH=$(get_real_path "$2")

if [ ! -d "$DATASET_PATH" ]; then
    echo "error: DATASET_PATH=$DATASET_PATH is not a directory"
    exit 1
fi

# Always start from a fresh working directory holding a copy of the sources.
if [ -d "distribute_train" ]; then
    rm -rf ./distribute_train
fi
mkdir ./distribute_train
cp ../*.py ./distribute_train
cp -r ../src ./distribute_train
cd ./distribute_train || exit

# Training runs in the background; stdout and stderr are collected in log.txt.
mpirun --allow-run-as-root -n "$RANK_SIZE" \
    python train.py \
    --dataset_path="$DATASET_PATH" \
    --platform=GPU \
    --run_distribute > log.txt 2>&1 &
cd ..

View File

@ -14,8 +14,8 @@
# limitations under the License.
# ============================================================================
if [ $# != 2 ]; then
echo "Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]"
if [ $# != 3 ]; then
echo "Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH] [PLATFORM]"
exit 1
fi
@ -29,6 +29,7 @@ get_real_path() {
PATH1=$(get_real_path $1)
PATH2=$(get_real_path $2)
PLATFORM=$3
if [ ! -d $PATH1 ]; then
echo "error: DATASET_PATH=$PATH1 is not a directory"
@ -40,21 +41,44 @@ if [ ! -f $PATH2 ]; then
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0
run_ascend() {
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_SIZE=$DEVICE_NUM
export RANK_ID=0
if [ -d "eval" ]; then
rm -rf ./eval
if [ -d "eval" ]; then
rm -rf ./eval
fi
mkdir ./eval
cp ../*.py ./eval
cp -r ../src ./eval
cd ./eval || exit
env >env.log
echo "start evaluation for device $DEVICE_ID"
python eval.py --dataset_path=$1 --checkpoint_path=$2 --platform=Ascend > log.txt 2>&1 &
cd ..
}
# Run evaluation on GPU in a fresh ./eval working directory.
#   $1  dataset directory
#   $2  checkpoint file
run_gpu() {
    if [ -d "eval" ]; then
        rm -rf ./eval
    fi
    mkdir ./eval
    cp ../*.py ./eval
    cp -r ../src ./eval
    cd ./eval || exit
    env >env.log
    # Quote the paths so that directories containing whitespace are passed intact.
    python eval.py --dataset_path="$1" --checkpoint_path="$2" --platform=GPU > log.txt 2>&1 &
    cd ..
}
# Dispatch on the requested platform; quoting keeps `[` well-formed even when
# PLATFORM is empty or contains whitespace.
if [ "Ascend" == "$PLATFORM" ]; then
    run_ascend "$PATH1" "$PATH2"
elif [ "GPU" == "$PLATFORM" ]; then
    run_gpu "$PATH1" "$PATH2"
else
    echo "error: PLATFORM=$PLATFORM is not support, only support Ascend and GPU."
    # Report the failure to the caller instead of exiting with status 0.
    exit 1
fi
mkdir ./eval
cp ../*.py ./eval
cp *.sh ./eval
cp -r ../src ./eval
cd ./eval || exit
env >env.log
echo "start evaluation for device $DEVICE_ID"
python eval.py --dataset_path=$PATH1 --checkpoint_path=$PATH2 &>log &
cd ..

View File

@ -14,8 +14,8 @@
# limitations under the License.
# ============================================================================
if [ $# != 1 ]; then
echo "Usage: sh run_standalone_train.sh [DATASET_PATH]"
if [ $# != 2 ]; then
echo "Usage: sh run_standalone_train.sh [DATASET_PATH] [PLATFORM]"
exit 1
fi
@ -28,27 +28,44 @@ get_real_path() {
}
PATH1=$(get_real_path $1)
PLATFORM=$2
if [ ! -d $PATH1 ]; then
echo "error: DATASET_PATH=$PATH1 is not a directory"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=1
export DEVICE_ID=0
export RANK_ID=0
export RANK_SIZE=1
# Start single-device training on Ascend in the background.
#   $1  dataset directory
# Expects to be called from inside the working directory and leaves it on return.
run_ascend() {
    ulimit -u unlimited
    export DEVICE_NUM=1
    export DEVICE_ID=0
    export RANK_ID=0
    export RANK_SIZE=1

    echo "start training for device $DEVICE_ID"
    env >env.log
    # Quote the path so that directories containing whitespace are passed intact.
    python train.py --dataset_path="$1" --platform=Ascend > log.txt 2>&1 &
    cd ..
}
# Start single-device training on GPU in the background.
#   $1  dataset directory
# Expects to be called from inside the working directory and leaves it on return.
run_gpu() {
    env >env.log
    # Quote the path so that directories containing whitespace are passed intact.
    python train.py --dataset_path="$1" --platform=GPU > log.txt 2>&1 &
    cd ..
}
if [ -d "train" ]; then
rm -rf ./train
rm -rf ./train
fi
mkdir ./train
cp ../*.py ./train
cp *.sh ./train
cp -r ../src ./train
cd ./train || exit
echo "start training for device $DEVICE_ID"
env >env.log
python train.py --dataset=$PATH1 &>log &
cd ..
# Dispatch on the requested platform; quoting keeps `[` well-formed even when
# PLATFORM is empty or contains whitespace.
if [ "Ascend" == "$PLATFORM" ]; then
    run_ascend "$PATH1"
elif [ "GPU" == "$PLATFORM" ]; then
    run_gpu "$PATH1"
else
    echo "error: PLATFORM=$PLATFORM is not support, only support Ascend and GPU."
    # Report the failure to the caller instead of exiting with status 0.
    exit 1
fi

View File

@ -24,24 +24,25 @@ from PIL import Image
from src.config import config as cf
class _CaptchaDataset():
class _CaptchaDataset:
"""
create train or evaluation dataset for warpctc
Args:
img_root_dir(str): root path of images
max_captcha_digits(int): max number of digits in images.
blank(int): value reserved for blank label, default is 10. When parsing label from image file names, if label
length is less than max_captcha_digits, the remaining labels are padding with blank.
device_target(str): platform of training, support Ascend and GPU.
"""
def __init__(self, img_root_dir, max_captcha_digits, blank=10):
def __init__(self, img_root_dir, max_captcha_digits, device_target='Ascend'):
if not os.path.exists(img_root_dir):
raise RuntimeError("the input image dir {} is invalid!".format(img_root_dir))
self.img_root_dir = img_root_dir
self.img_names = [i for i in os.listdir(img_root_dir) if i.endswith('.png')]
self.max_captcha_digits = max_captcha_digits
self.blank = blank
self.target = device_target
self.blank = 10 if self.target == 'Ascend' else 0
self.label_length = [len(os.path.splitext(n)[0].split('-')[-1]) for n in self.img_names]
def __len__(self):
return len(self.img_names)
@ -54,27 +55,33 @@ class _CaptchaDataset():
image = np.array(im)
label_str = os.path.splitext(img_name)[0]
label_str = label_str[label_str.find('-') + 1:]
label = [int(i) for i in label_str]
label.extend([int(self.blank)] * (self.max_captcha_digits - len(label)))
if self.target == 'Ascend':
label = [int(i) for i in label_str]
label.extend([int(self.blank)] * (self.max_captcha_digits - len(label)))
else:
label = [int(i) + 1 for i in label_str]
length = len(label)
label.extend([int(self.blank)] * (self.max_captcha_digits - len(label)))
label.append(length)
label = np.array(label)
return image, label
def create_dataset(dataset_path, repeat_num=1, batch_size=1):
def create_dataset(dataset_path, batch_size=1, num_shards=1, shard_id=0, device_target='Ascend'):
"""
create train or evaluation dataset for warpctc
Args:
dataset_path(int): dataset path
repeat_num(int): dataset repetition num, default is 1
batch_size(int): batch size of generated dataset, default is 1
num_shards(int): number of devices
shard_id(int): rank id
device_target(str): platform of training, support Ascend and GPU
"""
rank_size = int(os.environ.get("RANK_SIZE")) if os.environ.get("RANK_SIZE") else 1
rank_id = int(os.environ.get("RANK_ID")) if os.environ.get("RANK_ID") else 0
dataset = _CaptchaDataset(dataset_path, cf.max_captcha_digits)
ds = de.GeneratorDataset(dataset, ["image", "label"], shuffle=True, num_shards=rank_size, shard_id=rank_id)
ds.set_dataset_size(m.ceil(len(dataset) / rank_size))
dataset = _CaptchaDataset(dataset_path, cf.max_captcha_digits, device_target)
ds = de.GeneratorDataset(dataset, ["image", "label"], shuffle=True, num_shards=num_shards, shard_id=shard_id)
ds.set_dataset_size(m.ceil(len(dataset) / num_shards))
image_trans = [
vc.Rescale(1.0 / 255.0, 0.0),
vc.Normalize([0.9010, 0.9049, 0.9025], std=[0.1521, 0.1347, 0.1458]),
@ -87,6 +94,5 @@ def create_dataset(dataset_path, repeat_num=1, batch_size=1):
ds = ds.map(input_columns=["image"], num_parallel_workers=8, operations=image_trans)
ds = ds.map(input_columns=["label"], num_parallel_workers=8, operations=label_trans)
ds = ds.batch(batch_size)
ds = ds.repeat(repeat_num)
ds = ds.batch(batch_size, drop_remainder=True)
return ds

View File

@ -47,3 +47,25 @@ class CTCLoss(_Loss):
labels_values = self.reshape(label, (-1,))
loss, _ = self.ctc_loss(logit, self.labels_indices, labels_values, self.sequence_length)
return loss
class CTCLossV2(_Loss):
    """
    CTC loss built on the CTCLossV2 operator.

    The label tensor passed to construct carries the label values in every column
    except the last one; the last column holds the label length of each sample.

    Args:
        max_sequence_length(int): max number of sequence length. For captcha images, the value is equal to image width
        batch_size(int): batch size of input logits
    """

    def __init__(self, max_sequence_length, batch_size):
        super(CTCLossV2, self).__init__()
        # every sample in the batch uses the same (maximum) input length
        lengths = np.full(batch_size, max_sequence_length)
        self.input_length = Tensor(lengths, mstype.int32)
        self.reshape = P.Reshape()
        self.ctc_loss = P.CTCLossV2()

    def construct(self, logit, label):
        # split the packed label tensor into values and per-sample lengths
        target_values = label[:, :-1]
        target_lengths = label[:, -1]
        loss, _ = self.ctc_loss(logit, target_values, self.input_length, target_lengths)
        return loss

View File

@ -15,19 +15,19 @@
"""Metric for accuracy evaluation."""
from mindspore import nn
BLANK_LABLE = 10
class WarpCTCAccuracy(nn.Metric):
"""
Define accuracy metric for warpctc network.
"""
def __init__(self):
def __init__(self, device_target='Ascend'):
    """
    Initialize the counters and the platform dependent blank label.

    Args:
        device_target(str): running platform. The blank label index is 10 for 'Ascend'
            and 0 for any other value.
    """
    # super() needs the instance as well, otherwise the base class initializer never runs
    super(WarpCTCAccuracy, self).__init__()
    self._correct_num = 0
    self._total_num = 0
    self._count = 0
    self.device_target = device_target
    self.blank = 10 if device_target == 'Ascend' else 0
def clear(self):
self._correct_num = 0
@ -39,6 +39,8 @@ class WarpCTCAccuracy(nn.Metric):
y_pred = self._convert_data(inputs[0])
y = self._convert_data(inputs[1])
if self.device_target == 'GPU':
y = y[:, :-1]
self._count += 1
@ -54,8 +56,7 @@ class WarpCTCAccuracy(nn.Metric):
raise RuntimeError('Accuary can not be calculated, because the number of samples is 0.')
return self._correct_num / self._total_num
@staticmethod
def _is_eq(pred_lbl, target):
def _is_eq(self, pred_lbl, target):
"""
check whether predict label is equal to target label
"""
@ -63,11 +64,10 @@ class WarpCTCAccuracy(nn.Metric):
pred_diff = len(target) - len(pred_lbl)
if pred_diff > 0:
# padding by BLANK_LABLE
pred_lbl.extend([BLANK_LABLE] * pred_diff)
pred_lbl.extend([self.blank] * pred_diff)
return pred_lbl == target
@staticmethod
def _get_prediction(y_pred):
def _get_prediction(self, y_pred):
"""
parse predict result to labels
"""
@ -78,11 +78,11 @@ class WarpCTCAccuracy(nn.Metric):
pred_lbls = []
for i in range(batch_size):
idx = indices[:, i]
last_idx = BLANK_LABLE
last_idx = self.blank
pred_lbl = []
for j in range(lens[i]):
cur_idx = idx[j]
if cur_idx not in [last_idx, BLANK_LABLE]:
if cur_idx not in [last_idx, self.blank]:
pred_lbl.append(cur_idx)
last_idx = cur_idx
pred_lbls.append(pred_lbl)

View File

@ -88,3 +88,52 @@ class StackedRNN(nn.Cell):
output = self.concat((output, h2_after_fc))
return output
class StackedRNNForGPU(nn.Cell):
    """
    Define a stacked RNN network which contains stacked LSTM layers and one full-connect layer.

    Args:
        input_size(int): Size of time sequence. Usually, the input_size is equal to three times of image height for
            captcha images.
        batch_size(int): batch size of input data, default is 64
        hidden_size(int): the hidden size in LSTM layers, default is 512
        num_layer(int): the number of layer of LSTM, default is 2
    """

    def __init__(self, input_size, batch_size=64, hidden_size=512, num_layer=2):
        super(StackedRNNForGPU, self).__init__()
        self.batch_size = batch_size
        self.input_size = input_size
        self.num_classes = 11
        self.reshape = P.Reshape()
        self.cast = P.Cast()
        k = (1 / hidden_size) ** 0.5
        # Size of the flattened LSTM parameter for `num_layer` unidirectional layers with bias:
        # the first layer sees `input_size` inputs, every further layer sees `hidden_size` inputs,
        # and each layer adds recurrent weights plus two bias vectors per gate.
        # For num_layer == 2 this equals 4 * hidden_size * (input_size + 3 * hidden_size + 4).
        # NOTE(review): assumes nn.LSTM keeps one flattened weight in this layout - confirm
        # against the installed MindSpore version.
        weight_shape = 4 * hidden_size * (input_size + (2 * num_layer - 1) * hidden_size + 2 * num_layer)
        self.weight = Parameter(np.random.uniform(-k, k, (weight_shape, 1, 1)).astype(np.float32), name='weight')
        # initial hidden and cell states are all zeros
        self.h = Tensor(np.zeros(shape=(num_layer, batch_size, hidden_size)).astype(np.float32))
        self.c = Tensor(np.zeros(shape=(num_layer, batch_size, hidden_size)).astype(np.float32))

        # the layer count must follow num_layer so that it matches the shapes of h, c and weight
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layer)
        self.lstm.weight = self.weight

        self.fc_weight = np.random.random((self.num_classes, hidden_size)).astype(np.float32)
        self.fc_bias = np.random.random(self.num_classes).astype(np.float32)

        self.fc = nn.Dense(in_channels=hidden_size, out_channels=self.num_classes, weight_init=Tensor(self.fc_weight),
                           bias_init=Tensor(self.fc_bias))

        self.fc.to_float(mstype.float32)
        self.expand_dims = P.ExpandDims()
        self.concat = P.Concat()
        self.transpose = P.Transpose()

    def construct(self, x):
        # move axis 3 of the input to the front and flatten the rest into
        # (steps, batch_size, input_size); presumably axis 3 is the image width - TODO confirm
        x = self.transpose(x, (3, 0, 2, 1))
        x = self.reshape(x, (-1, self.batch_size, self.input_size))
        output, _ = self.lstm(x, (self.h, self.c))
        # apply the full-connect layer step by step and stack the results along axis 0
        res = ()
        for i in range(F.shape(x)[0]):
            res += (self.expand_dims(self.fc(output[i]), 0),)
        res = self.concat(res)
        return res

View File

@ -42,7 +42,7 @@ grad_div = C.MultitypeFuncGraph("grad_div")
@grad_div.register("Tensor", "Tensor")
def _grad_div(val, grad):
div = P.Div()
div = P.RealDiv()
mul = P.Mul()
grad = mul(grad, 10.0)
ret = div(grad, val)

View File

@ -24,12 +24,12 @@ from mindspore import dataset as de
from mindspore.train.model import Model, ParallelMode
from mindspore.nn.wrap import WithLossCell
from mindspore.train.callback import TimeMonitor, LossMonitor, CheckpointConfig, ModelCheckpoint
from mindspore.communication.management import init
from mindspore.communication.management import init, get_group_size, get_rank
from src.loss import CTCLoss
from src.loss import CTCLoss, CTCLossV2
from src.config import config as cf
from src.dataset import create_dataset
from src.warpctc import StackedRNN
from src.warpctc import StackedRNN, StackedRNNForGPU
from src.warpctc_for_train import TrainOneStepCellWithGradClip
from src.lr_schedule import get_lr
@ -38,38 +38,60 @@ np.random.seed(1)
de.config.set_seed(1)
parser = argparse.ArgumentParser(description="Warpctc training")
parser.add_argument("--run_distribute", type=bool, default=False, help="Run distribute, default is false.")
parser.add_argument('--device_num', type=int, default=1, help='Device num, default is 1.')
parser.add_argument("--run_distribute", action='store_true', help="Run distribute, default is false.")
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path, default is None')
parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend', 'GPU'],
help='Running platform, choose from Ascend, GPU, and default is Ascend.')
parser.set_defaults(run_distribute=False)
args_opt = parser.parse_args()
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(mode=context.GRAPH_MODE,
device_target="Ascend",
save_graphs=False,
device_id=device_id)
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, save_graphs=False)
if args_opt.platform == 'Ascend':
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(device_id=device_id)
if __name__ == '__main__':
lr_scale = 1
if args_opt.run_distribute:
if args_opt.platform == 'Ascend':
init()
lr_scale = 1
device_num = int(os.environ.get("RANK_SIZE"))
rank = int(os.environ.get("RANK_ID"))
else:
init('nccl')
lr_scale = 0.5
device_num = get_group_size()
rank = get_rank()
context.reset_auto_parallel_context()
context.set_auto_parallel_context(device_num=args_opt.device_num,
context.set_auto_parallel_context(device_num=device_num,
parallel_mode=ParallelMode.DATA_PARALLEL,
mirror_mean=True)
init()
else:
device_num = 1
rank = 0
max_captcha_digits = cf.max_captcha_digits
input_size = m.ceil(cf.captcha_height / 64) * 64 * 3
# create dataset
dataset = create_dataset(dataset_path=args_opt.dataset_path, repeat_num=1, batch_size=cf.batch_size)
dataset = create_dataset(dataset_path=args_opt.dataset_path, batch_size=cf.batch_size,
num_shards=device_num, shard_id=rank, device_target=args_opt.platform)
step_size = dataset.get_dataset_size()
# define lr
lr_init = cf.learning_rate if not args_opt.run_distribute else cf.learning_rate * args_opt.device_num
lr_init = cf.learning_rate if not args_opt.run_distribute else cf.learning_rate * device_num * lr_scale
lr = get_lr(cf.epoch_size, step_size, lr_init)
# define loss
loss = CTCLoss(max_sequence_length=cf.captcha_width, max_label_length=max_captcha_digits, batch_size=cf.batch_size)
# define net
net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
# define opt
opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=cf.momentum)
if args_opt.platform == 'Ascend':
loss = CTCLoss(max_sequence_length=cf.captcha_width,
max_label_length=max_captcha_digits,
batch_size=cf.batch_size)
net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=cf.momentum)
else:
loss = CTCLossV2(max_sequence_length=cf.captcha_width, batch_size=cf.batch_size)
net = StackedRNNForGPU(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
opt = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=cf.momentum)
net = WithLossCell(net, loss)
net = TrainOneStepCellWithGradClip(net, opt).set_train()
# define model
@ -79,6 +101,6 @@ if __name__ == '__main__':
if cf.save_checkpoint:
config_ck = CheckpointConfig(save_checkpoint_steps=cf.save_checkpoint_steps,
keep_checkpoint_max=cf.keep_checkpoint_max)
ckpt_cb = ModelCheckpoint(prefix="waptctc", directory=cf.save_checkpoint_path, config=config_ck)
ckpt_cb = ModelCheckpoint(prefix="warpctc", directory=cf.save_checkpoint_path, config=config_ck)
callbacks.append(ckpt_cb)
model.train(cf.epoch_size, dataset, callbacks=callbacks)