forked from mindspore-Ecosystem/mindspore
!3421 Add WarpCTC GPU script
Merge pull request !3421 from yangyongjie/master
This commit is contained in:
commit
669a8969c7
|
@ -31,7 +31,8 @@ These is an example of training Warpctc with self-generated captcha image datase
|
|||
└──warpct
|
||||
├── README.md
|
||||
├── script
|
||||
├── run_distribute_train.sh # launch distributed training(8 pcs)
|
||||
├── run_distribute_train.sh # launch distributed training in Ascend(8 pcs)
|
||||
├── run_distribute_train_for_gpu.sh # launch distributed training in GPU
|
||||
├── run_eval.sh # launch evaluation
|
||||
├── run_process_data.sh # launch dataset generation
|
||||
└── run_standalone_train.sh # launch standalone training(1 pcs)
|
||||
|
@ -75,22 +76,31 @@ Parameters for both training and evaluation can be set in config.py.
|
|||
#### Usage
|
||||
|
||||
```
|
||||
# distributed training
|
||||
# distributed training in Ascend
|
||||
Usage: sh run_distribute_train.sh [MINDSPORE_HCCL_CONFIG_PATH] [DATASET_PATH]
|
||||
|
||||
# distributed training in GPU
|
||||
Usage: sh run_distribute_train_for_gpu.sh [RANK_SIZE] [DATASET_PATH]
|
||||
|
||||
# standalone training
|
||||
Usage: sh run_standalone_train.sh [DATASET_PATH]
|
||||
Usage: sh run_standalone_train.sh [DATASET_PATH] [PLATFORM]
|
||||
```
|
||||
|
||||
|
||||
#### Launch
|
||||
|
||||
```
|
||||
# distribute training example
|
||||
# distribute training example in Ascend
|
||||
sh run_distribute_train.sh rank_table.json ../data/train
|
||||
|
||||
# standalone training example
|
||||
sh run_standalone_train.sh ../data/train
|
||||
# distribute training example in GPU
|
||||
sh run_distribute_train.sh 8 ../data/train
|
||||
|
||||
# standalone training example in Ascend
|
||||
sh run_standalone_train.sh ../data/train Ascend
|
||||
|
||||
# standalone training example in GPU
|
||||
sh run_standalone_train.sh ../data/train GPU
|
||||
```
|
||||
|
||||
> About rank_table.json, you can refer to the [distributed training tutorial](https://www.mindspore.cn/tutorial/en/master/advanced_use/distributed_training.html).
|
||||
|
@ -116,14 +126,17 @@ Epoch: [ 5/ 30], step: [ 98/ 98], loss: [0.0186/0.0186], time: [75199.5809]
|
|||
|
||||
```
|
||||
# evaluation
|
||||
Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]
|
||||
Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH] [PLATFORM]
|
||||
```
|
||||
|
||||
#### Launch
|
||||
|
||||
```
|
||||
# evaluation example
|
||||
sh run_eval.sh ../data/test warpctc-30-98.ckpt
|
||||
# evaluation example in Ascend
|
||||
sh run_eval.sh ../data/test warpctc-30-98.ckpt Ascend
|
||||
|
||||
# evaluation example in GPU
|
||||
sh run_eval.sh ../data/test warpctc-30-98.ckpt GPU
|
||||
```
|
||||
|
||||
> checkpoint can be produced in training process.
|
||||
|
|
|
@ -23,10 +23,10 @@ from mindspore import dataset as de
|
|||
from mindspore.train.model import Model
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
|
||||
from src.loss import CTCLoss
|
||||
from src.loss import CTCLoss, CTCLossV2
|
||||
from src.config import config as cf
|
||||
from src.dataset import create_dataset
|
||||
from src.warpctc import StackedRNN
|
||||
from src.warpctc import StackedRNN, StackedRNNForGPU
|
||||
from src.metric import WarpCTCAccuracy
|
||||
|
||||
random.seed(1)
|
||||
|
@ -36,30 +36,38 @@ de.config.set_seed(1)
|
|||
parser = argparse.ArgumentParser(description="Warpctc training")
|
||||
parser.add_argument("--dataset_path", type=str, default=None, help="Dataset, default is None.")
|
||||
parser.add_argument("--checkpoint_path", type=str, default=None, help="checkpoint file path, default is None")
|
||||
parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend', 'GPU'],
|
||||
help='Running platform, choose from Ascend, GPU, and default is Ascend.')
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(mode=context.GRAPH_MODE,
|
||||
device_target="Ascend",
|
||||
save_graphs=False,
|
||||
device_id=device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, save_graphs=False)
|
||||
if args_opt.platform == 'Ascend':
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(device_id=device_id)
|
||||
|
||||
if __name__ == '__main__':
|
||||
max_captcha_digits = cf.max_captcha_digits
|
||||
input_size = m.ceil(cf.captcha_height / 64) * 64 * 3
|
||||
# create dataset
|
||||
dataset = create_dataset(dataset_path=args_opt.dataset_path, repeat_num=1, batch_size=cf.batch_size)
|
||||
dataset = create_dataset(dataset_path=args_opt.dataset_path,
|
||||
batch_size=cf.batch_size,
|
||||
device_target=args_opt.platform)
|
||||
step_size = dataset.get_dataset_size()
|
||||
# define loss
|
||||
loss = CTCLoss(max_sequence_length=cf.captcha_width, max_label_length=max_captcha_digits, batch_size=cf.batch_size)
|
||||
# define net
|
||||
net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
|
||||
if args_opt.platform == 'Ascend':
|
||||
loss = CTCLoss(max_sequence_length=cf.captcha_width,
|
||||
max_label_length=max_captcha_digits,
|
||||
batch_size=cf.batch_size)
|
||||
net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
|
||||
else:
|
||||
loss = CTCLossV2(max_sequence_length=cf.captcha_width, batch_size=cf.batch_size)
|
||||
net = StackedRNNForGPU(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
|
||||
|
||||
# load checkpoint
|
||||
param_dict = load_checkpoint(args_opt.checkpoint_path)
|
||||
load_param_into_net(net, param_dict)
|
||||
net.set_train(False)
|
||||
# define model
|
||||
model = Model(net, loss_fn=loss, metrics={'WarpCTCAccuracy': WarpCTCAccuracy()})
|
||||
model = Model(net, loss_fn=loss, metrics={'WarpCTCAccuracy': WarpCTCAccuracy(args_opt.platform)})
|
||||
# start evaluation
|
||||
res = model.eval(dataset)
|
||||
res = model.eval(dataset, dataset_sink_mode=args_opt.platform == 'Ascend')
|
||||
print("result:", res, flush=True)
|
||||
|
|
|
@ -57,6 +57,6 @@ for ((i = 0; i < ${DEVICE_NUM}; i++)); do
|
|||
cd ./train_parallel$i || exit
|
||||
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||
env >env.log
|
||||
python train.py --run_distribute=True --device_num=$DEVICE_NUM --dataset_path=$PATH2 &>log &
|
||||
python train.py --platform=Ascend --dataset_path=$PATH2 --run_distribute > log.txt 2>&1 &
|
||||
cd ..
|
||||
done
|
||||
|
|
|
@ -0,0 +1,52 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 2 ]; then
|
||||
echo "Usage: sh run_distribute_train.sh [RANK_SIZE] [DATASET_PATH]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path() {
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
RANK_SIZE=$1
|
||||
DATASET_PATH=$(get_real_path $2)
|
||||
|
||||
if [ ! -d $DATASET_PATH ]; then
|
||||
echo "error: DATASET_PATH=$DATASET_PATH is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ -d "distribute_train" ]; then
|
||||
rm -rf ./distribute_train
|
||||
fi
|
||||
|
||||
mkdir ./distribute_train
|
||||
cp ../*.py ./distribute_train
|
||||
cp -r ../src ./distribute_train
|
||||
cd ./distribute_train || exit
|
||||
|
||||
mpirun --allow-run-as-root -n $RANK_SIZE \
|
||||
python train.py \
|
||||
--dataset_path=$DATASET_PATH \
|
||||
--platform=GPU \
|
||||
--run_distribute > log.txt 2>&1 &
|
||||
cd ..
|
|
@ -14,8 +14,8 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 2 ]; then
|
||||
echo "Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH]"
|
||||
if [ $# != 3 ]; then
|
||||
echo "Usage: sh run_eval.sh [DATASET_PATH] [CHECKPOINT_PATH] [PLATFORM]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
@ -29,6 +29,7 @@ get_real_path() {
|
|||
|
||||
PATH1=$(get_real_path $1)
|
||||
PATH2=$(get_real_path $2)
|
||||
PLATFORM=$3
|
||||
|
||||
if [ ! -d $PATH1 ]; then
|
||||
echo "error: DATASET_PATH=$PATH1 is not a directory"
|
||||
|
@ -40,21 +41,44 @@ if [ ! -f $PATH2 ]; then
|
|||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=1
|
||||
export DEVICE_ID=0
|
||||
export RANK_SIZE=$DEVICE_NUM
|
||||
export RANK_ID=0
|
||||
run_ascend() {
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=1
|
||||
export DEVICE_ID=0
|
||||
export RANK_SIZE=$DEVICE_NUM
|
||||
export RANK_ID=0
|
||||
|
||||
if [ -d "eval" ]; then
|
||||
rm -rf ./eval
|
||||
if [ -d "eval" ]; then
|
||||
rm -rf ./eval
|
||||
fi
|
||||
mkdir ./eval
|
||||
cp ../*.py ./eval
|
||||
cp -r ../src ./eval
|
||||
cd ./eval || exit
|
||||
env >env.log
|
||||
echo "start evaluation for device $DEVICE_ID"
|
||||
python eval.py --dataset_path=$1 --checkpoint_path=$2 --platform=Ascend > log.txt 2>&1 &
|
||||
cd ..
|
||||
}
|
||||
|
||||
run_gpu() {
|
||||
if [ -d "eval" ]; then
|
||||
rm -rf ./eval
|
||||
fi
|
||||
mkdir ./eval
|
||||
cp ../*.py ./eval
|
||||
cp -r ../src ./eval
|
||||
cd ./eval || exit
|
||||
env >env.log
|
||||
python eval.py --dataset_path=$1 --checkpoint_path=$2 --platform=GPU > log.txt 2>&1 &
|
||||
cd ..
|
||||
}
|
||||
|
||||
if [ "Ascend" == $PLATFORM ]; then
|
||||
run_ascend $PATH1 $PATH2
|
||||
elif [ "GPU" == $PLATFORM ]; then
|
||||
run_gpu $PATH1 $PATH2
|
||||
else
|
||||
echo "error: PLATFORM=$PLATFORM is not support, only support Ascend and GPU."
|
||||
fi
|
||||
mkdir ./eval
|
||||
cp ../*.py ./eval
|
||||
cp *.sh ./eval
|
||||
cp -r ../src ./eval
|
||||
cd ./eval || exit
|
||||
env >env.log
|
||||
echo "start evaluation for device $DEVICE_ID"
|
||||
python eval.py --dataset_path=$PATH1 --checkpoint_path=$PATH2 &>log &
|
||||
cd ..
|
||||
|
||||
|
|
|
@ -14,8 +14,8 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 1 ]; then
|
||||
echo "Usage: sh run_standalone_train.sh [DATASET_PATH]"
|
||||
if [ $# != 2 ]; then
|
||||
echo "Usage: sh run_standalone_train.sh [DATASET_PATH] [PLATFORM]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
@ -28,27 +28,44 @@ get_real_path() {
|
|||
}
|
||||
|
||||
PATH1=$(get_real_path $1)
|
||||
PLATFORM=$2
|
||||
|
||||
if [ ! -d $PATH1 ]; then
|
||||
echo "error: DATASET_PATH=$PATH1 is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=1
|
||||
export DEVICE_ID=0
|
||||
export RANK_ID=0
|
||||
export RANK_SIZE=1
|
||||
run_ascend() {
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=1
|
||||
export DEVICE_ID=0
|
||||
export RANK_ID=0
|
||||
export RANK_SIZE=1
|
||||
|
||||
echo "start training for device $DEVICE_ID"
|
||||
env >env.log
|
||||
python train.py --dataset_path=$1 --platform=Ascend > log.txt 2>&1 &
|
||||
cd ..
|
||||
}
|
||||
|
||||
run_gpu() {
|
||||
env >env.log
|
||||
python train.py --dataset_path=$1 --platform=GPU > log.txt 2>&1 &
|
||||
cd ..
|
||||
}
|
||||
|
||||
if [ -d "train" ]; then
|
||||
rm -rf ./train
|
||||
rm -rf ./train
|
||||
fi
|
||||
mkdir ./train
|
||||
cp ../*.py ./train
|
||||
cp *.sh ./train
|
||||
cp -r ../src ./train
|
||||
cd ./train || exit
|
||||
echo "start training for device $DEVICE_ID"
|
||||
env >env.log
|
||||
python train.py --dataset=$PATH1 &>log &
|
||||
cd ..
|
||||
|
||||
if [ "Ascend" == $PLATFORM ]; then
|
||||
run_ascend $PATH1
|
||||
elif [ "GPU" == $PLATFORM ]; then
|
||||
run_gpu $PATH1
|
||||
else
|
||||
echo "error: PLATFORM=$PLATFORM is not support, only support Ascend and GPU."
|
||||
fi
|
|
@ -24,24 +24,25 @@ from PIL import Image
|
|||
from src.config import config as cf
|
||||
|
||||
|
||||
class _CaptchaDataset():
|
||||
class _CaptchaDataset:
|
||||
"""
|
||||
create train or evaluation dataset for warpctc
|
||||
|
||||
Args:
|
||||
img_root_dir(str): root path of images
|
||||
max_captcha_digits(int): max number of digits in images.
|
||||
blank(int): value reserved for blank label, default is 10. When parsing label from image file names, if label
|
||||
length is less than max_captcha_digits, the remaining labels are padding with blank.
|
||||
device_target(str): platform of training, support Ascend and GPU.
|
||||
"""
|
||||
|
||||
def __init__(self, img_root_dir, max_captcha_digits, blank=10):
|
||||
def __init__(self, img_root_dir, max_captcha_digits, device_target='Ascend'):
|
||||
if not os.path.exists(img_root_dir):
|
||||
raise RuntimeError("the input image dir {} is invalid!".format(img_root_dir))
|
||||
self.img_root_dir = img_root_dir
|
||||
self.img_names = [i for i in os.listdir(img_root_dir) if i.endswith('.png')]
|
||||
self.max_captcha_digits = max_captcha_digits
|
||||
self.blank = blank
|
||||
self.target = device_target
|
||||
self.blank = 10 if self.target == 'Ascend' else 0
|
||||
self.label_length = [len(os.path.splitext(n)[0].split('-')[-1]) for n in self.img_names]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.img_names)
|
||||
|
@ -54,27 +55,33 @@ class _CaptchaDataset():
|
|||
image = np.array(im)
|
||||
label_str = os.path.splitext(img_name)[0]
|
||||
label_str = label_str[label_str.find('-') + 1:]
|
||||
label = [int(i) for i in label_str]
|
||||
label.extend([int(self.blank)] * (self.max_captcha_digits - len(label)))
|
||||
if self.target == 'Ascend':
|
||||
label = [int(i) for i in label_str]
|
||||
label.extend([int(self.blank)] * (self.max_captcha_digits - len(label)))
|
||||
else:
|
||||
label = [int(i) + 1 for i in label_str]
|
||||
length = len(label)
|
||||
label.extend([int(self.blank)] * (self.max_captcha_digits - len(label)))
|
||||
label.append(length)
|
||||
label = np.array(label)
|
||||
return image, label
|
||||
|
||||
|
||||
def create_dataset(dataset_path, repeat_num=1, batch_size=1):
|
||||
def create_dataset(dataset_path, batch_size=1, num_shards=1, shard_id=0, device_target='Ascend'):
|
||||
"""
|
||||
create train or evaluation dataset for warpctc
|
||||
|
||||
Args:
|
||||
dataset_path(int): dataset path
|
||||
repeat_num(int): dataset repetition num, default is 1
|
||||
batch_size(int): batch size of generated dataset, default is 1
|
||||
num_shards(int): number of devices
|
||||
shard_id(int): rank id
|
||||
device_target(str): platform of training, support Ascend and GPU
|
||||
"""
|
||||
rank_size = int(os.environ.get("RANK_SIZE")) if os.environ.get("RANK_SIZE") else 1
|
||||
rank_id = int(os.environ.get("RANK_ID")) if os.environ.get("RANK_ID") else 0
|
||||
|
||||
dataset = _CaptchaDataset(dataset_path, cf.max_captcha_digits)
|
||||
ds = de.GeneratorDataset(dataset, ["image", "label"], shuffle=True, num_shards=rank_size, shard_id=rank_id)
|
||||
ds.set_dataset_size(m.ceil(len(dataset) / rank_size))
|
||||
dataset = _CaptchaDataset(dataset_path, cf.max_captcha_digits, device_target)
|
||||
ds = de.GeneratorDataset(dataset, ["image", "label"], shuffle=True, num_shards=num_shards, shard_id=shard_id)
|
||||
ds.set_dataset_size(m.ceil(len(dataset) / num_shards))
|
||||
image_trans = [
|
||||
vc.Rescale(1.0 / 255.0, 0.0),
|
||||
vc.Normalize([0.9010, 0.9049, 0.9025], std=[0.1521, 0.1347, 0.1458]),
|
||||
|
@ -87,6 +94,5 @@ def create_dataset(dataset_path, repeat_num=1, batch_size=1):
|
|||
ds = ds.map(input_columns=["image"], num_parallel_workers=8, operations=image_trans)
|
||||
ds = ds.map(input_columns=["label"], num_parallel_workers=8, operations=label_trans)
|
||||
|
||||
ds = ds.batch(batch_size)
|
||||
ds = ds.repeat(repeat_num)
|
||||
ds = ds.batch(batch_size, drop_remainder=True)
|
||||
return ds
|
||||
|
|
|
@ -47,3 +47,25 @@ class CTCLoss(_Loss):
|
|||
labels_values = self.reshape(label, (-1,))
|
||||
loss, _ = self.ctc_loss(logit, self.labels_indices, labels_values, self.sequence_length)
|
||||
return loss
|
||||
|
||||
|
||||
class CTCLossV2(_Loss):
|
||||
"""
|
||||
CTCLoss definition
|
||||
|
||||
Args:
|
||||
max_sequence_length(int): max number of sequence length. For captcha images, the value is equal to image width
|
||||
batch_size(int): batch size of input logits
|
||||
"""
|
||||
|
||||
def __init__(self, max_sequence_length, batch_size):
|
||||
super(CTCLossV2, self).__init__()
|
||||
self.input_length = Tensor(np.array([max_sequence_length] * batch_size), mstype.int32)
|
||||
self.reshape = P.Reshape()
|
||||
self.ctc_loss = P.CTCLossV2()
|
||||
|
||||
def construct(self, logit, label):
|
||||
labels_values = label[:, :-1]
|
||||
labels_length = label[:, -1]
|
||||
loss, _ = self.ctc_loss(logit, labels_values, self.input_length, labels_length)
|
||||
return loss
|
||||
|
|
|
@ -15,19 +15,19 @@
|
|||
"""Metric for accuracy evaluation."""
|
||||
from mindspore import nn
|
||||
|
||||
BLANK_LABLE = 10
|
||||
|
||||
|
||||
class WarpCTCAccuracy(nn.Metric):
|
||||
"""
|
||||
Define accuracy metric for warpctc network.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
def __init__(self, device_target='Ascend'):
|
||||
super(WarpCTCAccuracy).__init__()
|
||||
self._correct_num = 0
|
||||
self._total_num = 0
|
||||
self._count = 0
|
||||
self.device_target = device_target
|
||||
self.blank = 10 if device_target == 'Ascend' else 0
|
||||
|
||||
def clear(self):
|
||||
self._correct_num = 0
|
||||
|
@ -39,6 +39,8 @@ class WarpCTCAccuracy(nn.Metric):
|
|||
|
||||
y_pred = self._convert_data(inputs[0])
|
||||
y = self._convert_data(inputs[1])
|
||||
if self.device_target == 'GPU':
|
||||
y = y[:, :-1]
|
||||
|
||||
self._count += 1
|
||||
|
||||
|
@ -54,8 +56,7 @@ class WarpCTCAccuracy(nn.Metric):
|
|||
raise RuntimeError('Accuary can not be calculated, because the number of samples is 0.')
|
||||
return self._correct_num / self._total_num
|
||||
|
||||
@staticmethod
|
||||
def _is_eq(pred_lbl, target):
|
||||
def _is_eq(self, pred_lbl, target):
|
||||
"""
|
||||
check whether predict label is equal to target label
|
||||
"""
|
||||
|
@ -63,11 +64,10 @@ class WarpCTCAccuracy(nn.Metric):
|
|||
pred_diff = len(target) - len(pred_lbl)
|
||||
if pred_diff > 0:
|
||||
# padding by BLANK_LABLE
|
||||
pred_lbl.extend([BLANK_LABLE] * pred_diff)
|
||||
pred_lbl.extend([self.blank] * pred_diff)
|
||||
return pred_lbl == target
|
||||
|
||||
@staticmethod
|
||||
def _get_prediction(y_pred):
|
||||
def _get_prediction(self, y_pred):
|
||||
"""
|
||||
parse predict result to labels
|
||||
"""
|
||||
|
@ -78,11 +78,11 @@ class WarpCTCAccuracy(nn.Metric):
|
|||
pred_lbls = []
|
||||
for i in range(batch_size):
|
||||
idx = indices[:, i]
|
||||
last_idx = BLANK_LABLE
|
||||
last_idx = self.blank
|
||||
pred_lbl = []
|
||||
for j in range(lens[i]):
|
||||
cur_idx = idx[j]
|
||||
if cur_idx not in [last_idx, BLANK_LABLE]:
|
||||
if cur_idx not in [last_idx, self.blank]:
|
||||
pred_lbl.append(cur_idx)
|
||||
last_idx = cur_idx
|
||||
pred_lbls.append(pred_lbl)
|
||||
|
|
|
@ -88,3 +88,52 @@ class StackedRNN(nn.Cell):
|
|||
output = self.concat((output, h2_after_fc))
|
||||
|
||||
return output
|
||||
|
||||
|
||||
class StackedRNNForGPU(nn.Cell):
|
||||
"""
|
||||
Define a stacked RNN network which contains two LSTM layers and one full-connect layer.
|
||||
|
||||
Args:
|
||||
input_size(int): Size of time sequence. Usually, the input_size is equal to three times of image height for
|
||||
captcha images.
|
||||
batch_size(int): batch size of input data, default is 64
|
||||
hidden_size(int): the hidden size in LSTM layers, default is 512
|
||||
num_layer(int): the number of layer of LSTM.
|
||||
"""
|
||||
def __init__(self, input_size, batch_size=64, hidden_size=512, num_layer=2):
|
||||
super(StackedRNNForGPU, self).__init__()
|
||||
self.batch_size = batch_size
|
||||
self.input_size = input_size
|
||||
self.num_classes = 11
|
||||
self.reshape = P.Reshape()
|
||||
self.cast = P.Cast()
|
||||
k = (1 / hidden_size) ** 0.5
|
||||
weight_shape = 4 * hidden_size * (input_size + 3 * hidden_size + 4)
|
||||
self.weight = Parameter(np.random.uniform(-k, k, (weight_shape, 1, 1)).astype(np.float32), name='weight')
|
||||
self.h = Tensor(np.zeros(shape=(num_layer, batch_size, hidden_size)).astype(np.float32))
|
||||
self.c = Tensor(np.zeros(shape=(num_layer, batch_size, hidden_size)).astype(np.float32))
|
||||
|
||||
self.lstm = nn.LSTM(input_size, hidden_size, num_layers=2)
|
||||
self.lstm.weight = self.weight
|
||||
|
||||
self.fc_weight = np.random.random((self.num_classes, hidden_size)).astype(np.float32)
|
||||
self.fc_bias = np.random.random(self.num_classes).astype(np.float32)
|
||||
|
||||
self.fc = nn.Dense(in_channels=hidden_size, out_channels=self.num_classes, weight_init=Tensor(self.fc_weight),
|
||||
bias_init=Tensor(self.fc_bias))
|
||||
|
||||
self.fc.to_float(mstype.float32)
|
||||
self.expand_dims = P.ExpandDims()
|
||||
self.concat = P.Concat()
|
||||
self.transpose = P.Transpose()
|
||||
|
||||
def construct(self, x):
|
||||
x = self.transpose(x, (3, 0, 2, 1))
|
||||
x = self.reshape(x, (-1, self.batch_size, self.input_size))
|
||||
output, _ = self.lstm(x, (self.h, self.c))
|
||||
res = ()
|
||||
for i in range(F.shape(x)[0]):
|
||||
res += (self.expand_dims(self.fc(output[i]), 0),)
|
||||
res = self.concat(res)
|
||||
return res
|
||||
|
|
|
@ -42,7 +42,7 @@ grad_div = C.MultitypeFuncGraph("grad_div")
|
|||
|
||||
@grad_div.register("Tensor", "Tensor")
|
||||
def _grad_div(val, grad):
|
||||
div = P.Div()
|
||||
div = P.RealDiv()
|
||||
mul = P.Mul()
|
||||
grad = mul(grad, 10.0)
|
||||
ret = div(grad, val)
|
||||
|
|
|
@ -24,12 +24,12 @@ from mindspore import dataset as de
|
|||
from mindspore.train.model import Model, ParallelMode
|
||||
from mindspore.nn.wrap import WithLossCell
|
||||
from mindspore.train.callback import TimeMonitor, LossMonitor, CheckpointConfig, ModelCheckpoint
|
||||
from mindspore.communication.management import init
|
||||
from mindspore.communication.management import init, get_group_size, get_rank
|
||||
|
||||
from src.loss import CTCLoss
|
||||
from src.loss import CTCLoss, CTCLossV2
|
||||
from src.config import config as cf
|
||||
from src.dataset import create_dataset
|
||||
from src.warpctc import StackedRNN
|
||||
from src.warpctc import StackedRNN, StackedRNNForGPU
|
||||
from src.warpctc_for_train import TrainOneStepCellWithGradClip
|
||||
from src.lr_schedule import get_lr
|
||||
|
||||
|
@ -38,38 +38,60 @@ np.random.seed(1)
|
|||
de.config.set_seed(1)
|
||||
|
||||
parser = argparse.ArgumentParser(description="Warpctc training")
|
||||
parser.add_argument("--run_distribute", type=bool, default=False, help="Run distribute, default is false.")
|
||||
parser.add_argument('--device_num', type=int, default=1, help='Device num, default is 1.')
|
||||
parser.add_argument("--run_distribute", action='store_true', help="Run distribute, default is false.")
|
||||
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path, default is None')
|
||||
parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend', 'GPU'],
|
||||
help='Running platform, choose from Ascend, GPU, and default is Ascend.')
|
||||
parser.set_defaults(run_distribute=False)
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(mode=context.GRAPH_MODE,
|
||||
device_target="Ascend",
|
||||
save_graphs=False,
|
||||
device_id=device_id)
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, save_graphs=False)
|
||||
if args_opt.platform == 'Ascend':
|
||||
device_id = int(os.getenv('DEVICE_ID'))
|
||||
context.set_context(device_id=device_id)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
lr_scale = 1
|
||||
if args_opt.run_distribute:
|
||||
if args_opt.platform == 'Ascend':
|
||||
init()
|
||||
lr_scale = 1
|
||||
device_num = int(os.environ.get("RANK_SIZE"))
|
||||
rank = int(os.environ.get("RANK_ID"))
|
||||
else:
|
||||
init('nccl')
|
||||
lr_scale = 0.5
|
||||
device_num = get_group_size()
|
||||
rank = get_rank()
|
||||
context.reset_auto_parallel_context()
|
||||
context.set_auto_parallel_context(device_num=args_opt.device_num,
|
||||
context.set_auto_parallel_context(device_num=device_num,
|
||||
parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
mirror_mean=True)
|
||||
init()
|
||||
else:
|
||||
device_num = 1
|
||||
rank = 0
|
||||
|
||||
max_captcha_digits = cf.max_captcha_digits
|
||||
input_size = m.ceil(cf.captcha_height / 64) * 64 * 3
|
||||
# create dataset
|
||||
dataset = create_dataset(dataset_path=args_opt.dataset_path, repeat_num=1, batch_size=cf.batch_size)
|
||||
dataset = create_dataset(dataset_path=args_opt.dataset_path, batch_size=cf.batch_size,
|
||||
num_shards=device_num, shard_id=rank, device_target=args_opt.platform)
|
||||
step_size = dataset.get_dataset_size()
|
||||
# define lr
|
||||
lr_init = cf.learning_rate if not args_opt.run_distribute else cf.learning_rate * args_opt.device_num
|
||||
lr_init = cf.learning_rate if not args_opt.run_distribute else cf.learning_rate * device_num * lr_scale
|
||||
lr = get_lr(cf.epoch_size, step_size, lr_init)
|
||||
# define loss
|
||||
loss = CTCLoss(max_sequence_length=cf.captcha_width, max_label_length=max_captcha_digits, batch_size=cf.batch_size)
|
||||
# define net
|
||||
net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
|
||||
# define opt
|
||||
opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=cf.momentum)
|
||||
if args_opt.platform == 'Ascend':
|
||||
loss = CTCLoss(max_sequence_length=cf.captcha_width,
|
||||
max_label_length=max_captcha_digits,
|
||||
batch_size=cf.batch_size)
|
||||
net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
|
||||
opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=cf.momentum)
|
||||
else:
|
||||
loss = CTCLossV2(max_sequence_length=cf.captcha_width, batch_size=cf.batch_size)
|
||||
net = StackedRNNForGPU(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
|
||||
opt = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=cf.momentum)
|
||||
|
||||
net = WithLossCell(net, loss)
|
||||
net = TrainOneStepCellWithGradClip(net, opt).set_train()
|
||||
# define model
|
||||
|
@ -79,6 +101,6 @@ if __name__ == '__main__':
|
|||
if cf.save_checkpoint:
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=cf.save_checkpoint_steps,
|
||||
keep_checkpoint_max=cf.keep_checkpoint_max)
|
||||
ckpt_cb = ModelCheckpoint(prefix="waptctc", directory=cf.save_checkpoint_path, config=config_ck)
|
||||
ckpt_cb = ModelCheckpoint(prefix="warpctc", directory=cf.save_checkpoint_path, config=config_ck)
|
||||
callbacks.append(ckpt_cb)
|
||||
model.train(cf.epoch_size, dataset, callbacks=callbacks)
|
||||
|
|
Loading…
Reference in New Issue