Implemented GPU training for psenet

This commit is contained in:
Wei Sun 2021-07-26 15:55:01 -04:00
parent 59322811fd
commit d3416e7dcf
6 changed files with 187 additions and 27 deletions

View File

@@ -11,7 +11,8 @@
- [Script Parameters](#script-parameters)
- [Training Process](#training-process)
- [Training](#training)
- [Distributed Training](#distributed-training)
- [Distributed Ascend Training](#distributed-ascend-training)
- [Distributed GPU Training](#distributed-gpu-training)
- [Evaluation Process](#evaluation-process)
- [Evaluation](#evaluation)
- [Inference Process](#inference-process)
@@ -49,8 +50,8 @@ A testing set containing about 2000 readable words
# [Environment Requirements](#contents)
- Hardware (Ascend)
- Prepare hardware environment with Ascend processor.
- Hardware (Ascend or GPU)
- Prepare hardware environment with Ascend processor or GPU.
- Framework
- [MindSpore](http://www.mindspore.cn/install/en)
- For more information, please check the resources below
@@ -101,8 +102,10 @@ sh scripts/run_eval_ascend.sh
├── README_CN.md // descriptions about PSENet in Chinese
├── README.md // descriptions about PSENet in English
├── scripts
├── run_distribute_train.sh // shell script for distributed
└── run_eval_ascend.sh // shell script for evaluation
├── run_distribute_train.sh // shell script for distributed training on Ascend
├── run_distribute_train_gpu.sh // shell script for distributed training on GPU
├── run_eval_ascend.sh // shell script for evaluation on Ascend
├── run_eval_gpu.sh // shell script for evaluation on GPU
├── ascend310_infer // application for 310 inference
├── src
├── model_utils
@@ -136,6 +139,7 @@ sh scripts/run_eval_ascend.sh
```default_config.yaml
Major parameters in default_config.yaml are:
--device_target: Ascend or GPU
--pre_trained: Whether to train from scratch or to train from a
pre-trained model. Optional values are True, False.
--device_id: Device ID used to train or evaluate the dataset. Ignore it
@@ -145,9 +149,9 @@ Major parameters in default_config.yaml are:
## [Training Process](#contents)
### Distributed Training
### Distributed Ascend Training
For distributed training, a hccl configuration file with JSON format needs to be created in advance.
For distributed Ascend training, an HCCL configuration file in JSON format needs to be created in advance.
Please follow the instructions in the link below: <https://gitee.com/mindspore/mindspore/tree/master/model_zoo/utils/hccl_tools>.
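For reference, the rank table is usually generated with the hccl_tools helper from that repository. A minimal sketch of the command for an 8-device host follows; the exact flags may differ between MindSpore versions, so treat it as an assumption rather than the definitive invocation.

```shell
# Assumed hccl_tools invocation: writes a rank table JSON for Ascend devices 0-7;
# use the resulting file when launching distributed Ascend training.
python hccl_tools.py --device_num "[0,8)"
```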
@@ -169,6 +173,24 @@ device_1/log:epcoh: 2, step: 40, loss is 0.76629
```
### Distributed GPU Training
```shell
sh scripts/run_distribute_train_gpu.sh [PRETRAINED_PATH] [TRAIN_ROOT_DIR]
```
After training begins, the log and loss.log files can be found in the train_parallel directory.
```log
# cat train_parallel/loss.log
time: 2021-07-24 02:08:33, epoch: 10, step: 31, loss is 0.68408
time: 2021-07-24 02:08:33, epoch: 10, step: 31, loss is 0.67984
...
time: 2021-07-24 04:01:07, epoch: 90, step: 31, loss is 0.61662
time: 2021-07-24 04:01:07, epoch: 90, step: 31, loss is 0.58495
```
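run_distribute_train_gpu.sh launches eight data-parallel workers through mpirun (one per GPU, RANK_SIZE=8). As a purely illustrative invocation, assuming the ICDAR2015 training data sits under ./data/ICDAR2015/Train and a pretrained checkpoint is available at ./pretrained/psenet_pretrain.ckpt (both paths are hypothetical):

```shell
# Hypothetical paths; adjust them to your environment.
sh scripts/run_distribute_train_gpu.sh ./pretrained/psenet_pretrain.ckpt ./data/ICDAR2015/Train
```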
## [Evaluation Process](#contents)
### run test code
@@ -245,9 +267,11 @@ step 2: click "My Methods" button, then download Evaluation Scripts.
step 3: it is recommended to symlink the eval method root to $MINDSPORE/model_zoo/psenet/eval_ic15/. If your folder structure is different, you may need to change the corresponding paths in the eval script files.
```shell
sh ./script/run_eval_ascend.sh.sh
sh ./script/run_eval_ascend.sh
```
The two scripts ./script/run_eval_ascend.sh and ./script/run_eval_gpu.sh are identical; you may run either one to evaluate on ICDAR2015.
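As a hypothetical illustration of step 3, assuming the downloaded evaluation scripts were unpacked to ~/Downloads/ic15_eval (an illustrative path only):

```shell
# Link the official ICDAR2015 evaluation tool to the location the eval scripts expect,
# then run the GPU variant of the evaluation script.
ln -s ~/Downloads/ic15_eval $MINDSPORE/model_zoo/psenet/eval_ic15
sh ./script/run_eval_gpu.sh
```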
#### Result
Calculated!{"precision": 0.814796668299853, "recall": 0.8006740491092923, "hmean": 0.8076736279747451, "AP": 0}
@@ -324,6 +348,24 @@ The `res` folder is generated in the upper-level directory. For details about th
| Checkpoint for Fine tuning | 109.44M (.ckpt file) |
| Scripts | <https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/psenet> |
| Parameters | GPU |
| -------------------------- | ----------------------------------------------------------- |
| Model Version | PSENet |
| Resource | GPU (Tesla V100-PCIE); CPU 2.60 GHz, 26 cores; Memory 790G; OS Euler2.0 |
| uploaded Date | 07/24/2021 (month/day/year) |
| MindSpore Version | 1.3.0 |
| Dataset | ICDAR2015 |
| Training Parameters | start_lr=0.1; lr_scale=0.1 |
| Optimizer | SGD |
| Loss Function | DiceLoss |
| outputs | probability |
| Loss | 0.40 |
| Speed | 1pc: 2726 ms/step; 8pcs: 2726 ms/step |
| Total time | 1pc: 335.6 h; 8pcs: 41.95 h |
| Parameters (M) | 27.36 |
| Checkpoint for Fine tuning | 109.44M (.ckpt file) |
| Scripts | <https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/psenet> |
### Inference Performance
| Parameters | Ascend |

View File

@@ -1,5 +1,5 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -38,10 +38,17 @@ then
exit 1
fi
PATH2=$(get_real_path $2)
if [ ! -f $PATH2 ]
if [ ! -z "${2}" ];
then
echo "error: PRETRAINED_PATH=$PATH2 is not a file"
PATH2=$(get_real_path $2)
else
PATH2=$2
fi
PATH3=$(get_real_path $3)
if [ ! -d $PATH3 ]
then
echo "error: TRAIN_ROOT_DIR=$PATH3 is not a directory"
exit 1
fi
@@ -74,6 +81,6 @@ do
cd ${current_exec_path}/device_$i || exit
export RANK_ID=$i
export DEVICE_ID=$i
python ${current_exec_path}/train.py --run_distribute=True --pre_trained $PATH2 --TRAIN_ROOT_DIR=$3 >test_deep$i.log 2>&1 &
python ${current_exec_path}/train.py --run_distribute=True --pre_trained=$PATH2 --TRAIN_ROOT_DIR=$PATH3 >test_deep$i.log 2>&1 &
cd ${current_exec_path} || exit
done

View File

@@ -0,0 +1,69 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 1 ] && [ $# != 2 ]
then
echo "Usage: sh run_distribute_train.sh [PRETRAINED_PATH] [TRAIN_ROOT_DIR]"
exit 1
fi
get_real_path(){
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
if [ ! -z "${1}" ];
then
PATH1=$(get_real_path $1)
else
PATH1=$1
fi
PATH2=$(get_real_path $2)
if [ ! -d $PATH2 ]
then
echo "error: TRAIN_ROOT_DIR=$PATH2 is not a directory"
exit 1
fi
ulimit -u unlimited
export DEVICE_NUM=8
export RANK_SIZE=8
if [ -d "train_parallel" ];
then
rm -rf ./train_parallel
fi
mkdir ./train_parallel
cp ./*.py ./train_parallel
cp ./scripts/*.sh ./train_parallel
cp -r ./src ./train_parallel
cp ./*yaml ./train_parallel
cd ./train_parallel || exit
env > env.log
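# Launch one training process per GPU (RANK_SIZE=8) through mpirun; the pretrained
# checkpoint is passed to train.py only when PATH1 points to an existing file.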
if [ -f $PATH1 ]
then
mpirun --allow-run-as-root -n $RANK_SIZE --output-filename log_output --merge-stderr-to-stdout \
python train.py --run_distribute=True --device_target="GPU" --pre_trained=$PATH1 --TRAIN_ROOT_DIR=$PATH2 > log 2>&1 &
else
mpirun --allow-run-as-root -n $RANK_SIZE --output-filename log_output --merge-stderr-to-stdout \
python train.py --device_target="GPU" --run_distribute=True --TRAIN_ROOT_DIR=$PATH2 > log 2>&1 &
fi
cd .. || exit

View File

@@ -1,5 +1,5 @@
#!/bin/bash
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.

View File

@@ -0,0 +1,25 @@
#!/bin/bash
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
current_exec_path=$(pwd)
res_path=${current_exec_path}/res/submit_ic15/
eval_tool_path=${current_exec_path}/eval_ic15/
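# Zip the predicted results from res/submit_ic15 and score them with the official ICDAR2015 evaluation script.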
cd ${res_path} || exit
zip ${eval_tool_path}/submit.zip ./*
cd ${eval_tool_path} || exit
python ./script.py -s=submit.zip -g=gt.zip
cd ${current_exec_path} || exit

View File

@@ -1,4 +1,4 @@
# Copyright 2020 Huawei Technologies Co., Ltd
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -18,7 +18,7 @@ import ast
import operator
import mindspore.nn as nn
from mindspore import context
from mindspore.communication.management import init
from mindspore.communication.management import init, get_rank
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor
from mindspore.train.model import Model
from mindspore.context import ParallelMode
@@ -35,7 +35,6 @@ from src.model_utils.device_adapter import get_device_id, get_device_num, get_ra
set_seed(1)
context.set_context(mode=context.GRAPH_MODE, device_target="Ascend", device_id=get_device_id())
binOps = {
@@ -70,16 +69,25 @@ def modelarts_pre_process():
@moxing_wrapper(pre_process=modelarts_pre_process)
def train():
device_target = config.device_target
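# Select the execution backend (Ascend or GPU) from default_config.yaml instead of hard-coding Ascend.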
context.set_context(mode=context.GRAPH_MODE,
device_target=device_target,
device_id=get_device_id())
rank_id = 0
config.BASE_LR = arithmeticeval(config.BASE_LR)
config.WARMUP_RATIO = arithmeticeval(config.WARMUP_RATIO)
device_num = get_device_num()
if config.run_distribute:
context.set_auto_parallel_context(device_num=device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
context.set_auto_parallel_context(device_num=device_num,
parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True)
init()
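# Ascend takes the rank id from the device adapter; on GPU the rank is assigned by MPI, so get_rank() is used.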
rank_id = get_rank_id()
if device_target == 'Ascend':
rank_id = get_rank_id()
else:
rank_id = get_rank()
# dataset/network/criterion/optim
ds = train_dataset_creator(rank_id, device_num)
@@ -89,14 +97,18 @@ def train():
config.INFERENCE = False
net = ETSNet(config)
net = net.set_train()
param_dict = load_checkpoint(config.pre_trained)
load_param_into_net(net, param_dict)
print('Load Pretrained parameters done!')
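# Load pretrained parameters only when a checkpoint path is provided.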
if config.pre_trained:
param_dict = load_checkpoint(config.pre_trained)
load_param_into_net(net, param_dict)
print('Load Pretrained parameters done!')
criterion = DiceLoss(batch_size=config.TRAIN_BATCH_SIZE)
lrs = dynamic_lr(config.BASE_LR, config.TRAIN_TOTAL_ITER, config.WARMUP_STEP, config.WARMUP_RATIO)
opt = nn.SGD(params=net.trainable_params(), learning_rate=lrs, momentum=0.99, weight_decay=5e-4)
lrs = dynamic_lr(config.BASE_LR, config.TRAIN_TOTAL_ITER,
config.WARMUP_STEP, config.WARMUP_RATIO)
opt = nn.SGD(params=net.trainable_params(), learning_rate=lrs,
momentum=0.99, weight_decay=5e-4)
# warp model
net = WithLossCell(net, criterion)
@@ -109,11 +121,16 @@
loss_cb = LossCallBack(per_print_times=10)
# set checkpoint parameters; checkpoints are saved under config.TRAIN_MODEL_SAVE_PATH
ckpoint_cf = CheckpointConfig(save_checkpoint_steps=1875, keep_checkpoint_max=3)
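# Each rank writes checkpoints to its own ckpt_<rank_id> directory so parallel workers do not overwrite one another.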
ckpoint_cb = ModelCheckpoint(prefix="ETSNet", config=ckpoint_cf,
directory="{}/ckpt_{}".format(config.TRAIN_MODEL_SAVE_PATH, rank_id))
ckpoint_cb = ModelCheckpoint(prefix="ETSNet",
config=ckpoint_cf,
directory="{}/ckpt_{}".format(config.TRAIN_MODEL_SAVE_PATH,
rank_id))
model = Model(net)
model.train(config.TRAIN_REPEAT_NUM, ds, dataset_sink_mode=True, callbacks=[time_cb, loss_cb, ckpoint_cb])
model.train(config.TRAIN_REPEAT_NUM,
ds,
dataset_sink_mode=True,
callbacks=[time_cb, loss_cb, ckpoint_cb])
if __name__ == '__main__':