forked from mindspore-Ecosystem/mindspore
!15759 add warpctc CPU support
From: @zhao_ting_v Reviewed-by: @wuxuejian,@c_34 Signed-off-by: @wuxuejian,@c_34
This commit is contained in:
commit
fa5648add2
|
@ -37,8 +37,8 @@ The dataset is self-generated using a third-party library called [captcha](https
|
||||||
|
|
||||||
## [Environment Requirements](#contents)
|
## [Environment Requirements](#contents)
|
||||||
|
|
||||||
- Hardware(Ascend/GPU)
|
- Hardware(Ascend/GPU/CPU)
|
||||||
- Prepare hardware environment with Ascend or GPU processor.
|
- Prepare hardware environment with Ascend, GPU or CPU processor.
|
||||||
- Framework
|
- Framework
|
||||||
- [MindSpore](https://gitee.com/mindspore/mindspore)
|
- [MindSpore](https://gitee.com/mindspore/mindspore)
|
||||||
- For more information, please check the resources below:
|
- For more information, please check the resources below:
|
||||||
|
@ -68,13 +68,13 @@ The dataset is self-generated using a third-party library called [captcha](https
|
||||||
- Running on Ascend
|
- Running on Ascend
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# distribute training example in Ascend
|
# distribute training example on Ascend
|
||||||
$ bash run_distribute_train.sh rank_table.json ../data/train
|
$ bash run_distribute_train.sh rank_table.json ../data/train
|
||||||
|
|
||||||
# evaluation example in Ascend
|
# evaluation example on Ascend
|
||||||
$ bash run_eval.sh ../data/test warpctc-30-97.ckpt Ascend
|
$ bash run_eval.sh ../data/test warpctc-30-97.ckpt Ascend
|
||||||
|
|
||||||
# standalone training example in Ascend
|
# standalone training example on Ascend
|
||||||
$ bash run_standalone_train.sh ../data/train Ascend
|
$ bash run_standalone_train.sh ../data/train Ascend
|
||||||
|
|
||||||
```
|
```
|
||||||
|
@ -88,16 +88,30 @@ The dataset is self-generated using a third-party library called [captcha](https
|
||||||
- Running on GPU
|
- Running on GPU
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# distribute training example in GPU
|
# distribute training example on GPU
|
||||||
$ bash run_distribute_train_for_gpu.sh 8 ../data/train
|
$ bash run_distribute_train_for_gpu.sh 8 ../data/train
|
||||||
|
|
||||||
# standalone training example in GPU
|
# standalone training example on GPU
|
||||||
$ bash run_standalone_train.sh ../data/train GPU
|
$ bash run_standalone_train.sh ../data/train GPU
|
||||||
|
|
||||||
# evaluation example in GPU
|
# evaluation example on GPU
|
||||||
$ bash run_eval.sh ../data/test warpctc-30-97.ckpt GPU
|
$ bash run_eval.sh ../data/test warpctc-30-97.ckpt GPU
|
||||||
```
|
```
|
||||||
|
|
||||||
|
- Running on CPU
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# training example on CPU
|
||||||
|
$ bash run_standalone_train.sh ../data/train CPU
|
||||||
|
# or
|
||||||
|
python train.py --dataset_path=./data/train --platform=CPU
|
||||||
|
|
||||||
|
# evaluation example on CPU
|
||||||
|
$ bash run_eval.sh ../data/test warpctc-30-97.ckpt CPU
|
||||||
|
# or
|
||||||
|
python eval.py --dataset_path=./data/test --checkpoint_path=warpctc-30-97.ckpt --platform=CPU
|
||||||
|
```
|
||||||
|
|
||||||
## [Script Description](#contents)
|
## [Script Description](#contents)
|
||||||
|
|
||||||
### [Script and Sample Code](#contents)
|
### [Script and Sample Code](#contents)
|
||||||
|
|
|
@ -42,8 +42,8 @@ WarpCTC是带有一层FC神经网络的二层堆叠LSTM模型。详细信息请
|
||||||
|
|
||||||
## 环境要求
|
## 环境要求
|
||||||
|
|
||||||
- 硬件(Ascend/GPU)
|
- 硬件(Ascend/GPU/CPU)
|
||||||
- 使用Ascend或GPU处理器来搭建硬件环境。
|
- 使用Ascend、GPU或CPU处理器来搭建硬件环境。
|
||||||
- 框架
|
- 框架
|
||||||
- [MindSpore](https://gitee.com/mindspore/mindspore)
|
- [MindSpore](https://gitee.com/mindspore/mindspore)
|
||||||
- 如需查看详情,请参见如下资源:
|
- 如需查看详情,请参见如下资源:
|
||||||
|
@ -92,7 +92,7 @@ WarpCTC是带有一层FC神经网络的二层堆叠LSTM模型。详细信息请
|
||||||
- 在GPU环境运行
|
- 在GPU环境运行
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Ascend分布式训练示例
|
# GPU分布式训练示例
|
||||||
$ bash run_distribute_train_for_gpu.sh 8 ../data/train
|
$ bash run_distribute_train_for_gpu.sh 8 ../data/train
|
||||||
|
|
||||||
# GPU单机训练示例
|
# GPU单机训练示例
|
||||||
|
@ -102,6 +102,20 @@ WarpCTC是带有一层FC神经网络的二层堆叠LSTM模型。详细信息请
|
||||||
$ bash run_eval.sh ../data/test warpctc-30-97.ckpt GPU
|
$ bash run_eval.sh ../data/test warpctc-30-97.ckpt GPU
|
||||||
```
|
```
|
||||||
|
|
||||||
|
- 在CPU环境运行
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# CPU训练示例
|
||||||
|
$ bash run_standalone_train.sh ../data/train CPU
|
||||||
|
# 或者
|
||||||
|
python train.py --dataset_path=./data/train --platform=CPU
|
||||||
|
|
||||||
|
# CPU评估示例
|
||||||
|
$ bash run_eval.sh ../data/test warpctc-30-97.ckpt CPU
|
||||||
|
# 或者
|
||||||
|
python eval.py --dataset_path=./data/test --checkpoint_path=warpctc-30-97.ckpt --platform=CPU
|
||||||
|
```
|
||||||
|
|
||||||
## 脚本说明
|
## 脚本说明
|
||||||
|
|
||||||
### 脚本及样例代码
|
### 脚本及样例代码
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -24,7 +24,7 @@ from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
from src.loss import CTCLoss
|
from src.loss import CTCLoss
|
||||||
from src.config import config as cf
|
from src.config import config as cf
|
||||||
from src.dataset import create_dataset
|
from src.dataset import create_dataset
|
||||||
from src.warpctc import StackedRNN, StackedRNNForGPU
|
from src.warpctc import StackedRNN, StackedRNNForGPU, StackedRNNForCPU
|
||||||
from src.metric import WarpCTCAccuracy
|
from src.metric import WarpCTCAccuracy
|
||||||
|
|
||||||
set_seed(1)
|
set_seed(1)
|
||||||
|
@ -32,8 +32,8 @@ set_seed(1)
|
||||||
parser = argparse.ArgumentParser(description="Warpctc training")
|
parser = argparse.ArgumentParser(description="Warpctc training")
|
||||||
parser.add_argument("--dataset_path", type=str, default=None, help="Dataset, default is None.")
|
parser.add_argument("--dataset_path", type=str, default=None, help="Dataset, default is None.")
|
||||||
parser.add_argument("--checkpoint_path", type=str, default=None, help="checkpoint file path, default is None")
|
parser.add_argument("--checkpoint_path", type=str, default=None, help="checkpoint file path, default is None")
|
||||||
parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend', 'GPU'],
|
parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend', 'GPU', 'CPU'],
|
||||||
help='Running platform, choose from Ascend, GPU, and default is Ascend.')
|
help='Running platform, choose from Ascend, GPU or CPU, and default is Ascend.')
|
||||||
args_opt = parser.parse_args()
|
args_opt = parser.parse_args()
|
||||||
|
|
||||||
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, save_graphs=False)
|
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform, save_graphs=False)
|
||||||
|
@ -54,8 +54,10 @@ if __name__ == '__main__':
|
||||||
batch_size=cf.batch_size)
|
batch_size=cf.batch_size)
|
||||||
if args_opt.platform == 'Ascend':
|
if args_opt.platform == 'Ascend':
|
||||||
net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
|
net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
|
||||||
else:
|
elif args_opt.platform == 'GPU':
|
||||||
net = StackedRNNForGPU(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
|
net = StackedRNNForGPU(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
|
||||||
|
else:
|
||||||
|
net = StackedRNNForCPU(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
|
||||||
|
|
||||||
# load checkpoint
|
# load checkpoint
|
||||||
param_dict = load_checkpoint(args_opt.checkpoint_path)
|
param_dict = load_checkpoint(args_opt.checkpoint_path)
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -14,18 +14,19 @@
|
||||||
# ============================================================================
|
# ============================================================================
|
||||||
"""export checkpoint file into air models"""
|
"""export checkpoint file into air models"""
|
||||||
import argparse
|
import argparse
|
||||||
|
import math as m
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
|
from mindspore import Tensor, context, load_checkpoint, load_param_into_net, export
|
||||||
|
|
||||||
from src.warpctc import StackedRNN
|
from src.warpctc import StackedRNN, StackedRNNForGPU, StackedRNNForCPU
|
||||||
from src.config import config
|
from src.config import config
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description="warpctc_export")
|
parser = argparse.ArgumentParser(description="warpctc_export")
|
||||||
parser.add_argument("--device_id", type=int, default=0, help="Device id")
|
parser.add_argument("--device_id", type=int, default=0, help="Device id")
|
||||||
parser.add_argument("--ckpt_file", type=str, required=True, help="warpctc ckpt file.")
|
parser.add_argument("--ckpt_file", type=str, required=True, help="warpctc ckpt file.")
|
||||||
parser.add_argument("--file_name", type=str, default="warpctc", help="warpctc output file name.")
|
parser.add_argument("--file_name", type=str, default="warpctc", help="warpctc output file name.")
|
||||||
parser.add_argument("--file_format", type=str, choices=["AIR", "ONNX", "MINDIR"], default="MINDIR", help="file format")
|
parser.add_argument("--file_format", type=str, choices=["AIR", "MINDIR"], default="MINDIR", help="file format")
|
||||||
parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend",
|
parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend",
|
||||||
help="device target")
|
help="device target")
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
@ -34,15 +35,24 @@ context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
|
||||||
if args.device_target == "Ascend":
|
if args.device_target == "Ascend":
|
||||||
context.set_context(device_id=args.device_id)
|
context.set_context(device_id=args.device_id)
|
||||||
|
|
||||||
|
if args.file_format == "AIR" and args.device_target != "Ascend":
|
||||||
|
raise ValueError("export AIR must on Ascend")
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
input_size = m.ceil(config.captcha_height / 64) * 64 * 3
|
||||||
captcha_width = config.captcha_width
|
captcha_width = config.captcha_width
|
||||||
captcha_height = config.captcha_height
|
captcha_height = config.captcha_height
|
||||||
batch_size = config.batch_size
|
batch_size = config.batch_size
|
||||||
hidden_size = config.hidden_size
|
hidden_size = config.hidden_size
|
||||||
net = StackedRNN(captcha_height * 3, batch_size, hidden_size)
|
image = Tensor(np.zeros([batch_size, 3, captcha_height, captcha_width], np.float32))
|
||||||
|
if args.device_target == 'Ascend':
|
||||||
|
net = StackedRNN(input_size=input_size, batch_size=batch_size, hidden_size=hidden_size)
|
||||||
|
image = Tensor(np.zeros([batch_size, 3, captcha_height, captcha_width], np.float16))
|
||||||
|
elif args.device_target == 'GPU':
|
||||||
|
net = StackedRNNForGPU(input_size=input_size, batch_size=batch_size, hidden_size=hidden_size)
|
||||||
|
else:
|
||||||
|
net = StackedRNNForCPU(input_size=input_size, batch_size=batch_size, hidden_size=hidden_size)
|
||||||
param_dict = load_checkpoint(args.ckpt_file)
|
param_dict = load_checkpoint(args.ckpt_file)
|
||||||
load_param_into_net(net, param_dict)
|
load_param_into_net(net, param_dict)
|
||||||
net.set_train(False)
|
net.set_train(False)
|
||||||
|
|
||||||
image = Tensor(np.zeros([batch_size, 3, captcha_height, captcha_width], np.float16))
|
|
||||||
export(net, image, file_name=args.file_name, file_format=args.file_format)
|
export(net, image, file_name=args.file_name, file_format=args.file_format)
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -61,7 +61,7 @@ run_ascend() {
|
||||||
cd ..
|
cd ..
|
||||||
}
|
}
|
||||||
|
|
||||||
run_gpu() {
|
run_gpu_cpu() {
|
||||||
if [ -d "eval" ]; then
|
if [ -d "eval" ]; then
|
||||||
rm -rf ./eval
|
rm -rf ./eval
|
||||||
fi
|
fi
|
||||||
|
@ -70,15 +70,13 @@ run_gpu() {
|
||||||
cp -r ../src ./eval
|
cp -r ../src ./eval
|
||||||
cd ./eval || exit
|
cd ./eval || exit
|
||||||
env >env.log
|
env >env.log
|
||||||
python eval.py --dataset_path=$1 --checkpoint_path=$2 --platform=GPU > log.txt 2>&1 &
|
python eval.py --dataset_path=$1 --checkpoint_path=$2 --platform=$3 > log.txt 2>&1 &
|
||||||
cd ..
|
cd ..
|
||||||
}
|
}
|
||||||
|
|
||||||
if [ "Ascend" == $PLATFORM ]; then
|
if [ "Ascend" == $PLATFORM ]; then
|
||||||
run_ascend $PATH1 $PATH2
|
run_ascend $PATH1 $PATH2
|
||||||
elif [ "GPU" == $PLATFORM ]; then
|
|
||||||
run_gpu $PATH1 $PATH2
|
|
||||||
else
|
else
|
||||||
echo "error: PLATFORM=$PLATFORM is not support, only support Ascend and GPU."
|
run_gpu_cpu $PATH1 $PATH2 $PLATFORM
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
#!/bin/bash
|
#!/bin/bash
|
||||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -48,9 +48,9 @@ run_ascend() {
|
||||||
cd ..
|
cd ..
|
||||||
}
|
}
|
||||||
|
|
||||||
run_gpu() {
|
run_gpu_cpu() {
|
||||||
env >env.log
|
env >env.log
|
||||||
python train.py --dataset_path=$1 --platform=GPU > log.txt 2>&1 &
|
python train.py --dataset_path=$1 --platform=$2 > log.txt 2>&1 &
|
||||||
cd ..
|
cd ..
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -64,8 +64,6 @@ cd ./train || exit
|
||||||
|
|
||||||
if [ "Ascend" == $PLATFORM ]; then
|
if [ "Ascend" == $PLATFORM ]; then
|
||||||
run_ascend $PATH1
|
run_ascend $PATH1
|
||||||
elif [ "GPU" == $PLATFORM ]; then
|
|
||||||
run_gpu $PATH1
|
|
||||||
else
|
else
|
||||||
echo "error: PLATFORM=$PLATFORM is not support, only support Ascend and GPU."
|
run_gpu_cpu $PATH1 $PLATFORM
|
||||||
fi
|
fi
|
|
@ -1,4 +1,4 @@
|
||||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -131,3 +131,41 @@ class StackedRNNForGPU(nn.Cell):
|
||||||
res += (self.expand_dims(self.fc(output[i]), 0),)
|
res += (self.expand_dims(self.fc(output[i]), 0),)
|
||||||
res = self.concat(res)
|
res = self.concat(res)
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
class StackedRNNForCPU(nn.Cell):
|
||||||
|
"""
|
||||||
|
Define a stacked RNN network which contains two LSTM layers and one fully connected layer on CPU.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
input_size(int): Size of time sequence. Usually, the input_size is equal to three times the image height for
|
||||||
|
captcha images.
|
||||||
|
batch_size(int): batch size of input data, default is 64
|
||||||
|
hidden_size(int): the hidden size in LSTM layers, default is 512
|
||||||
|
num_classes(int): the number of classes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, input_size, batch_size=64, hidden_size=512, num_classes=11):
|
||||||
|
super(StackedRNNForCPU, self).__init__()
|
||||||
|
self.batch_size = batch_size
|
||||||
|
self.input_size = input_size
|
||||||
|
k = (1 / hidden_size) ** 0.5
|
||||||
|
self.w1 = Parameter(
|
||||||
|
np.random.uniform(-k, k, (4 * hidden_size * (input_size + hidden_size + 1), 1, 1)).astype(np.float32))
|
||||||
|
self.w2 = Parameter(
|
||||||
|
np.random.uniform(-k, k, (4 * hidden_size * (2 * hidden_size + 1), 1, 1)).astype(np.float32))
|
||||||
|
self.h = Tensor(np.zeros(shape=(1, batch_size, hidden_size)).astype(np.float32))
|
||||||
|
self.c = Tensor(np.zeros(shape=(1, batch_size, hidden_size)).astype(np.float32))
|
||||||
|
|
||||||
|
self.lstm_1 = nn.LSTMCell(input_size=input_size, hidden_size=hidden_size)
|
||||||
|
self.lstm_2 = nn.LSTMCell(input_size=hidden_size, hidden_size=hidden_size)
|
||||||
|
|
||||||
|
self.fc = nn.Dense(in_channels=hidden_size, out_channels=num_classes)
|
||||||
|
self.transpose = P.Transpose()
|
||||||
|
|
||||||
|
def construct(self, x):
|
||||||
|
x = self.transpose(x, (3, 0, 2, 1))
|
||||||
|
x = F.reshape(x, (-1, self.batch_size, self.input_size))
|
||||||
|
y1, _, _, _, _ = self.lstm_1(x, self.h, self.c, self.w1)
|
||||||
|
y2, _, _, _, _ = self.lstm_2(y1, self.h, self.c, self.w2)
|
||||||
|
output = self.fc(y2) # y2 shape: [time_step, bs, hidden_size] output shape: [time_step, bs, num_classes].
|
||||||
|
return output
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
# Copyright 2020-2021 Huawei Technologies Co., Ltd
|
||||||
#
|
#
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
# you may not use this file except in compliance with the License.
|
# you may not use this file except in compliance with the License.
|
||||||
|
@ -28,7 +28,7 @@ from mindspore.communication.management import init, get_group_size, get_rank
|
||||||
from src.loss import CTCLoss
|
from src.loss import CTCLoss
|
||||||
from src.config import config as cf
|
from src.config import config as cf
|
||||||
from src.dataset import create_dataset
|
from src.dataset import create_dataset
|
||||||
from src.warpctc import StackedRNN, StackedRNNForGPU
|
from src.warpctc import StackedRNN, StackedRNNForGPU, StackedRNNForCPU
|
||||||
from src.warpctc_for_train import TrainOneStepCellWithGradClip
|
from src.warpctc_for_train import TrainOneStepCellWithGradClip
|
||||||
from src.lr_schedule import get_lr
|
from src.lr_schedule import get_lr
|
||||||
|
|
||||||
|
@ -37,8 +37,8 @@ set_seed(1)
|
||||||
parser = argparse.ArgumentParser(description="Warpctc training")
|
parser = argparse.ArgumentParser(description="Warpctc training")
|
||||||
parser.add_argument("--run_distribute", action='store_true', help="Run distribute, default is false.")
|
parser.add_argument("--run_distribute", action='store_true', help="Run distribute, default is false.")
|
||||||
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path, default is None')
|
parser.add_argument('--dataset_path', type=str, default=None, help='Dataset path, default is None')
|
||||||
parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend', 'GPU'],
|
parser.add_argument('--platform', type=str, default='Ascend', choices=['Ascend', 'GPU', 'CPU'],
|
||||||
help='Running platform, choose from Ascend, GPU, and default is Ascend.')
|
help='Running platform, choose from Ascend, GPU or CPU, and default is Ascend.')
|
||||||
parser.set_defaults(run_distribute=False)
|
parser.set_defaults(run_distribute=False)
|
||||||
args_opt = parser.parse_args()
|
args_opt = parser.parse_args()
|
||||||
|
|
||||||
|
@ -80,8 +80,10 @@ if __name__ == '__main__':
|
||||||
batch_size=cf.batch_size)
|
batch_size=cf.batch_size)
|
||||||
if args_opt.platform == 'Ascend':
|
if args_opt.platform == 'Ascend':
|
||||||
net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
|
net = StackedRNN(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
|
||||||
else:
|
elif args_opt.platform == 'GPU':
|
||||||
net = StackedRNNForGPU(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
|
net = StackedRNNForGPU(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
|
||||||
|
else:
|
||||||
|
net = StackedRNNForCPU(input_size=input_size, batch_size=cf.batch_size, hidden_size=cf.hidden_size)
|
||||||
opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=cf.momentum)
|
opt = nn.SGD(params=net.trainable_params(), learning_rate=lr, momentum=cf.momentum)
|
||||||
|
|
||||||
net = WithLossCell(net, loss)
|
net = WithLossCell(net, loss)
|
||||||
|
|
Loading…
Reference in New Issue