!15759 add warpctc CPU support

From: @zhao_ting_v
Reviewed-by: @wuxuejian,@c_34
Signed-off-by: @wuxuejian,@c_34
mindspore-ci-bot 2021-04-29 17:02:06 +08:00 committed by Gitee
commit fa5648add2
8 changed files with 116 additions and 40 deletions


@ -37,8 +37,8 @@ The dataset is self-generated using a third-party library called [captcha](https
## [Environment Requirements](#contents)
- Hardware（Ascend/GPU）
- Prepare hardware environment with Ascend or GPU processor.
- Hardware（Ascend/GPU/CPU）
- Prepare hardware environment with Ascend, GPU or CPU processor.
- Framework
- [MindSpore](https://gitee.com/mindspore/mindspore)
- For more information, please check the resources below
@ -68,13 +68,13 @@ The dataset is self-generated using a third-party library called [captcha](https
- Running on Ascend
```bash
# distribute training example in Ascend
# distribute training example on Ascend
$ bash run_distribute_train.sh rank_table.json ../data/train
# evaluation example in Ascend
# evaluation example on Ascend
$ bash run_eval.sh ../data/test warpctc-30-97.ckpt Ascend
# standalone training example in Ascend
# standalone training example on Ascend
$ bash run_standalone_train.sh ../data/train Ascend
```
@ -88,16 +88,30 @@ The dataset is self-generated using a third-party library called [captcha](https
- Running on GPU
```bash
# distribute training example in GPU
# distribute training example on GPU
$ bash run_distribute_train_for_gpu.sh 8 ../data/train
# standalone training example in GPU
# standalone training example on GPU
$ bash run_standalone_train.sh ../data/train GPU
# evaluation example in GPU
# evaluation example on GPU
$ bash run_eval.sh ../data/test warpctc-30-97.ckpt GPU
```
- Running on CPU
```bash
# training example on CPU
$ bash run_standalone_train.sh ../data/train CPU
or
python train.py --dataset_path=./data/train --platform=CPU
# evaluation example on CPU
$ bash run_eval.sh ../data/test warpctc-30-97.ckpt CPU
or
python eval.py --dataset_path=./data/test --checkpoint_path=warpctc-30-97.ckpt --platform=CPU
```
## [Script Description](#contents)
### [Script and Sample Code](#contents)


@ -42,8 +42,8 @@ WarpCTC is a two-layer stacked LSTM model with one FC layer. For details, please
## Environment Requirements
- Hardware（Ascend/GPU）
- Prepare the hardware environment with an Ascend or GPU processor.
- Hardware（Ascend/GPU/CPU）
- Prepare the hardware environment with an Ascend, GPU or CPU processor.
- Framework
- [MindSpore](https://gitee.com/mindspore/mindspore)
- For more information, please check the resources below:
@ -92,7 +92,7 @@ WarpCTC is a two-layer stacked LSTM model with one FC layer. For details, please
- Running on GPU
```bash
# Ascend distributed training example
# GPU distributed training example
$ bash run_distribute_train_for_gpu.sh 8 ../data/train
# GPU standalone training example
@ -102,6 +102,20 @@ WarpCTC is a two-layer stacked LSTM model with one FC layer. For details, please
$ bash run_eval.sh ../data/test warpctc-30-97.ckpt GPU
```
- Running on CPU
```bash
# CPU training example
$ bash run_standalone_train.sh ../data/train CPU
or
python train.py --dataset_path=./data/train --platform=CPU
# CPU evaluation example
$ bash run_eval.sh ../data/test warpctc-30-97.ckpt CPU
or
python eval.py --dataset_path=./data/test --checkpoint_path=warpctc-30-97.ckpt --platform=CPU
```
## Script Description
### Script and Sample Code


@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -24,7 +24,7 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.loss import CTCLoss
from src.config import config as cf
from src.dataset import create_dataset
from src.warpctc import StackedRNN, StackedRNNForGPU
from src.warpctc import StackedRNN, StackedRNNForGPU, StackedRNNForCPU
from src.metric import WarpCTCAccuracy
set_seed(1)
@ -32,8 +32,8 @@ set_seed(1)
parser = argparse.ArgumentParser(description="Warpctc training")
parser.add_argument("--dataset_path", type=str, default=None, help="Dataset, default is None.")
parser.add_argument("--checkpoint_path", type=str, default=None, help="checkpoint file path, default is None")
parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend', 'GPU'],
help='Running platform, choose from Ascend, GPU, and default is Ascend.')
parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend', 'GPU', 'CPU'],
help='Running platform, choose from Ascend, GPU or CPU, and default is Ascend.')
args_opt = parser.parse_args()
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, save_graphs=False)
@ -54,8 +54,10 @@ if __name__ == '__main__':
batch_size=cf.batch_size)
    if args_opt.platform == 'Ascend':
        net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
    else:
    elif args_opt.platform == 'GPU':
        net = StackedRNNForGPU(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
    else:
        net = StackedRNNForCPU(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
    # load checkpoint
    param_dict = load_checkpoint(args_opt.checkpoint_path)


@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -14,18 +14,19 @@
# ============================================================================
"""export checkpoint file into air models"""
import argparse
import math as m
import numpy as np
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
from src.warpctc import StackedRNN
from src.warpctc import StackedRNN, StackedRNNForGPU, StackedRNNForCPU
from src.config import config
parser = argparse.ArgumentParser(description="warpctc_export")
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--ckpt_file", type=str, required=True, help="warpctc ckpt file.")
parser.add_argument("--file_name", type=str, default="warpctc", help="warpctc output file name.")
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="MINDIR", help="file format")
parser.add_argument("--file_format", type=str, choices=["AIR", "MINDIR"], default="MINDIR", help="file format")
parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend",
help="device target")
args = parser.parse_args()
@ -34,15 +35,24 @@ context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
if args.device_target == "Ascend":
    context.set_context(device_id=args.device_id)
if args.file_format == "AIR" and args.device_target != "Ascend":
    raise ValueError("export AIR must on Ascend")
if __name__ == "__main__":
    input_size = m.ceil(config.captcha_height / 64) * 64 * 3
    captcha_width = config.captcha_width
    captcha_height = config.captcha_height
    batch_size = config.batch_size
    hidden_size = config.hidden_size
    net = StackedRNN(captcha_height * 3, batch_size, hidden_size)
    image = Tensor(np.zeros([batch_size, 3, captcha_height, captcha_width], np.float32))
    if args.device_target == 'Ascend':
        net = StackedRNN(input_size=input_size, batch_size=batch_size, hidden_size=hidden_size)
        image = Tensor(np.zeros([batch_size, 3, captcha_height, captcha_width], np.float16))
    elif args.device_target == 'GPU':
        net = StackedRNNForGPU(input_size=input_size, batch_size=batch_size, hidden_size=hidden_size)
    else:
        net = StackedRNNForCPU(input_size=input_size, batch_size=batch_size, hidden_size=hidden_size)
    param_dict = load_checkpoint(args.ckpt_file)
    load_param_into_net(net, param_dict)
    net.set_train(False)
    image = Tensor(np.zeros([batch_size, 3, captcha_height, captcha_width], np.float16))
    export(net, image, file_name=args.file_name, file_format=args.file_format)
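For reference, one way to keep the dummy export input's dtype in step with the per-target network choice above is sketched below. This is a hedged illustration, not the committed code: it reuses the `args`, `batch_size`, `captcha_height`, `captcha_width` and `net` names defined earlier in this file, and assumes the Ascend `StackedRNN` consumes float16 input (as in the branch above) while the GPU and CPU variants consume float32 (the new StackedRNNForCPU below creates its weights and states in float32).

```python
# Sketch only, standing in for the tail of the __main__ block above.
# Assumption: the dummy input's dtype should match the selected network's precision.
dtype = np.float16 if args.device_target == "Ascend" else np.float32
image = Tensor(np.zeros([batch_size, 3, captcha_height, captcha_width], dtype))
export(net, image, file_name=args.file_name, file_format=args.file_format)
```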


@ -1,5 +1,5 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -61,7 +61,7 @@ run_ascend() {
cd ..
}
run_gpu() {
run_gpu_cpu() {
if [ -d "eval" ]; then
rm -rf ./eval
fi
@ -70,15 +70,13 @@ run_gpu() {
cp -r ../src ./eval
cd ./eval || exit
env >env.log
python eval.py --dataset_path=$1 --checkpoint_path=$2 --platform=GPU > log.txt 2>&1 &
python eval.py --dataset_path=$1 --checkpoint_path=$2 --platform=$3 > log.txt 2>&1 &
cd ..
}
if [ "Ascend" == $PLATFORM ]; then
run_ascend $PATH1 $PATH2
elif [ "GPU" == $PLATFORM ]; then
run_gpu $PATH1 $PATH2
else
echo "error: PLATFORM=$PLATFORM is not support, only support Ascend and GPU."
run_gpu_cpu $PATH1 $PATH2 $PLATFORM
fi


@ -1,5 +1,5 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -48,9 +48,9 @@ run_ascend() {
cd ..
}
run_gpu() {
run_gpu_cpu() {
env >env.log
python train.py --dataset_path=$1 --platform=GPU > log.txt 2>&1 &
python train.py --dataset_path=$1 --platform=$2 > log.txt 2>&1 &
cd ..
}
@ -64,8 +64,6 @@ cd ./train || exit
if [ "Ascend" == $PLATFORM ]; then
run_ascend $PATH1
elif [ "GPU" == $PLATFORM ]; then
run_gpu $PATH1
else
echo "error: PLATFORM=$PLATFORM is not support, only support Ascend and GPU."
run_gpu_cpu $PATH1 $PLATFORM
fi


@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -131,3 +131,41 @@ class StackedRNNForGPU(nn.Cell):
            res += (self.expand_dims(self.fc(output[i]), 0),)
        res = self.concat(res)
        return res
class StackedRNNForCPU(nn.Cell):
    """
    Define a stacked RNN network which contains two LSTM layers and one fully connected layer on CPU.
    Args:
        input_size(int): Size of time sequence. Usually, the input_size is equal to three times of image height for
            captcha images.
        batch_size(int): batch size of input data, default is 64
        hidden_size(int): the hidden size in LSTM layers, default is 512
        num_classes(int): the number of classes.
    """
    def __init__(self, input_size, batch_size=64, hidden_size=512, num_classes=11):
        super(StackedRNNForCPU, self).__init__()
        self.batch_size = batch_size
        self.input_size = input_size
        k = (1 / hidden_size) ** 0.5
        self.w1 = Parameter(
            np.random.uniform(-k, k, (4 * hidden_size * (input_size + hidden_size + 1), 1, 1)).astype(np.float32))
        self.w2 = Parameter(
            np.random.uniform(-k, k, (4 * hidden_size * (2 * hidden_size + 1), 1, 1)).astype(np.float32))
        self.h = Tensor(np.zeros(shape=(1, batch_size, hidden_size)).astype(np.float32))
        self.c = Tensor(np.zeros(shape=(1, batch_size, hidden_size)).astype(np.float32))
        self.lstm_1 = nn.LSTMCell(input_size=input_size, hidden_size=hidden_size)
        self.lstm_2 = nn.LSTMCell(input_size=hidden_size, hidden_size=hidden_size)
        self.fc = nn.Dense(in_channels=hidden_size, out_channels=num_classes)
        self.transpose = P.Transpose()
    def construct(self, x):
        x = self.transpose(x, (3, 0, 2, 1))
        x = F.reshape(x, (-1, self.batch_size, self.input_size))
        y1, _, _, _, _ = self.lstm_1(x, self.h, self.c, self.w1)
        y2, _, _, _, _ = self.lstm_2(y1, self.h, self.c, self.w2)
        output = self.fc(y2)  # y2 shape: [time_step, bs, hidden_size]; output shape: [time_step, bs, num_classes]
        return output
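The flattened weight shapes above come from packing all four LSTM gates into a single parameter: each gate needs `input_size + hidden_size + 1` values per hidden unit (input weights, recurrent weights and a bias), giving `4 * hidden_size * (input_size + hidden_size + 1)` for the first layer and, since the second layer consumes the first layer's `hidden_size`-wide output, `4 * hidden_size * (2 * hidden_size + 1)` for the second. A minimal usage sketch of the new CPU cell follows; the captcha height/width and batch size are assumed example values (the real ones come from `src/config.py`), so treat it as an illustration rather than part of this change.

```python
# Sketch only: push one dummy batch of captcha images through StackedRNNForCPU.
# captcha_height=64, captcha_width=160 and batch_size=64 are assumed values here.
import numpy as np
from mindspore import Tensor, context
from src.warpctc import StackedRNNForCPU

context.set_context(mode=context.GRAPH_MODE, device_target="CPU")

captcha_height, captcha_width, batch_size = 64, 160, 64
net = StackedRNNForCPU(input_size=captcha_height * 3, batch_size=batch_size, hidden_size=512)
images = Tensor(np.zeros([batch_size, 3, captcha_height, captcha_width], np.float32))
logits = net(images)
print(logits.shape)  # (captcha_width, batch_size, num_classes); time-major logits fed to CTCLoss in training
```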


@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -28,7 +28,7 @@ from mindspore.communication.management import init, get_group_size, get_rank
from src.loss import CTCLoss
from src.config import config as cf
from src.dataset import create_dataset
from src.warpctc import StackedRNN, StackedRNNForGPU
from src.warpctc import StackedRNN, StackedRNNForGPU, StackedRNNForCPU
from src.warpctc_for_train import TrainOneStepCellWithGradClip
from src.lr_schedule import get_lr
@ -37,8 +37,8 @@ set_seed(1)
parser = argparse.ArgumentParser(description="Warpctc training")
parser.add_argument("--run_distribute", action='store_true', help="Run distribute, default is false.")
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path, default is None')
parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend', 'GPU'],
help='Running platform, choose from Ascend, GPU, and default is Ascend.')
parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend', 'GPU', 'CPU'],
help='Running platform, choose from Ascend, GPU or CPU, and default is Ascend.')
parser.set_defaults(run_distribute=False)
args_opt = parser.parse_args()
@ -80,8 +80,10 @@ if __name__ == '__main__':
batch_size=cf.batch_size)
    if args_opt.platform == 'Ascend':
        net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
    else:
    elif args_opt.platform == 'GPU':
        net = StackedRNNForGPU(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
    else:
        net = StackedRNNForCPU(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
    opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=cf.momentum)
    net = WithLossCell(net, loss)