!18212 fcn8s support gpu

Merge pull request !18212 from 周莉莉/master
i-robot 2021-06-16 09:15:31 +08:00 committed by Gitee
commit b71a20ed58
5 changed files with 290 additions and 36 deletions

model_zoo/official/cv/FCN8s/README.md Normal file → Executable file

@ -66,6 +66,27 @@ Dataset used:
python eval.py --device_id device_id
```
- Running on GPU with the default GPU parameters
```python
# GPU single-device training example
python train.py \
--config_path=gpu_default_config.yaml \
--device_target=GPU
# GPU multi-device training example
export RANK_SIZE=8
mpirun --allow-run-as-root -n $RANK_SIZE --output-filename log_output --merge-stderr-to-stdout \
python train.py \
--config_path=gpu_default_config.yaml \
--device_target=GPU
# GPU evaluation example
python eval.py \
--config_path=gpu_default_config.yaml \
--device_target=GPU
```
# [Script Description](#contents)
## [Script and Sample Code](#contents)
@ -79,6 +100,8 @@ Dataset used:
├── scripts
├── run_train.sh
├── run_standalone_train.sh
├── run_standalone_train_gpu.sh // train on GPU with a single device
├── run_distribute_train_gpu.sh // train on GPU with multiple devices
├── run_eval.sh
├── run_infer_310.sh // shell script for Ascend inference
├── build_data.sh
@ -97,7 +120,8 @@ Dataset used:
│ ├──device_adapter.py // getting device info
│ ├──local_adapter.py // getting device info
│ ├──moxing_adapter.py // Decorator
├── default_config.yaml // Parameters config
├── default_config.yaml // Ascend parameters config
├── gpu_default_config.yaml // GPU parameters config
├── train.py // training script
├── postprogress.py // post-processing script for Ascend 310 inference
├── export.py // export the checkpoint file to air/mindir
@ -138,7 +162,7 @@ Dataset used:
'ckpt_dir': './ckpt',
```
For more information, see `default_config.yaml`.
For more information, see `default_config.yaml` for Ascend and `gpu_default_config.yaml` for GPU.
## [Data Generation](#contents)
@ -175,6 +199,31 @@ Dataset used:
sh scripts/run_train.sh [DEVICE_NUM] rank_table.json
```
- Running on GPU with the default GPU parameters
```python
# GPU single-device training example
python train.py \
--config_path=gpu_default_config.yaml \
--device_target=GPU
or
sh scripts/run_standalone_train_gpu.sh DEVICE_ID
# GPU 8-device training example
export RANK_SIZE=8
mpirun --allow-run-as-root -n $RANK_SIZE --output-filename log_output --merge-stderr-to-stdout \
python train.py \
--config_path=gpu_default_config.yaml \
--device_target=GPU
or
sh run_distribute_train_gpu.sh [RANK_SIZE] [TRAIN_DATA_DIR]
# GPU evaluation example
python eval.py \
--config_path=gpu_default_config.yaml \
--device_target=GPU
```
During training, the epoch, step, loss, and accuracy at each point are written to log.txt:
```python
@ -261,10 +310,12 @@ Dataset used:
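### Evaluation
- Evaluate on Ascend with the PASCAL VOC 2012 validation set
- Evaluate on Ascend or GPU with the PASCAL VOC 2012 validation set
Before running the command, check the path of the checkpoint used for evaluation. Set it to the absolute path of the checkpoint, for example "/data/workspace/mindspore_dataset/FCN/FCN/model_new/FCN8s-500_82.ckpt".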
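The absolute-path requirement can be checked before launching evaluation. The following is a minimal sketch; `check_ckpt_path` is a hypothetical helper and is not part of this repository.

```python
import os


def check_ckpt_path(path):
    """Return path if it is an absolute path to an existing .ckpt file, else raise."""
    # Hypothetical helper for illustration; eval.py does not call it.
    if not os.path.isabs(path):
        raise ValueError(f"checkpoint path must be absolute: {path}")
    if not path.endswith(".ckpt"):
        raise ValueError(f"expected a .ckpt file: {path}")
    if not os.path.isfile(path):
        raise FileNotFoundError(f"checkpoint not found: {path}")
    return path
```

A relative path such as `FCN8s-500_82.ckpt` raises `ValueError`, so the mistake surfaces before the network and dataset are built.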
- eval on Ascend
```python
python eval.py
```
@ -273,7 +324,7 @@ Dataset used:
sh scripts/run_eval.sh DATA_ROOT DATA_LST CKPT_PATH
```
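The python command above runs in the terminal, where you can view the evaluation result. The accuracy on the test set is reported as follows:
The python command above runs in the terminal, where you can view the evaluation result. The accuracy on the test set is reported in a form similar to the following: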
```python
mean IoU 0.6467
@ -306,6 +357,20 @@ python export.py
mean IoU 0.64519877
```
- eval on GPU
```python
python eval.py \
--config_path=gpu_default_config.yaml \
--device_target=GPU
```
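The python command above runs in the terminal, where you can view the evaluation result. The accuracy on the test set is reported in a form similar to the following: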
```python
mean IoU 0.6472
```
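For reference, mean IoU is commonly computed from a per-class confusion matrix as the average of TP / (TP + FP + FN). The sketch below follows that standard definition. Whether `eval.py` computes it in exactly this way is an assumption, and the pixel counts are toy values, not results from this model.

```python
def mean_iou(confusion):
    """Mean IoU from a square confusion matrix (rows: ground truth, columns: prediction)."""
    num_classes = len(confusion)
    ious = []
    for c in range(num_classes):
        tp = confusion[c][c]
        fn = sum(confusion[c]) - tp                                  # class c pixels predicted as another class
        fp = sum(confusion[r][c] for r in range(num_classes)) - tp  # other pixels predicted as class c
        denom = tp + fp + fn
        if denom > 0:  # skip classes absent from both truth and prediction
            ious.append(tp / denom)
    return sum(ious) / len(ious) if ious else 0.0


# Toy 2-class confusion matrix (hypothetical pixel counts)
toy = [[50, 10],
       [5, 35]]
print(f"mean IoU {mean_iou(toy):.4f}")  # prints "mean IoU 0.7346"
```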
# [Model Description](#contents)
## [Performance](#contents)
@ -314,35 +379,35 @@ python export.py
#### FCN8s on PASCAL VOC 2012
| Parameters | Ascend
| -------------------------- | -----------------------------------------------------------
| Model Version | FCN-8s
| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory 755G; OS Euler2.8
| uploaded Date | 12/30/2020 (month/day/year)
| MindSpore Version | 1.1.0-alpha
| Dataset | PASCAL VOC 2012 and SBD
| Training Parameters | epoch=500, steps=330, batch_size = 32, lr=0.015
| Optimizer | Momentum
| Loss Function | Softmax Cross Entropy
| outputs | probability
| Loss | 0.038
| Speed | 1pc: 564.652 ms/step;
| Parameters | Ascend | GPU |
| -------------------------- | ------------------------------------------------------------| -------------------------------------------------|
| Model Version | FCN-8s | FCN-8s |
| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory 755G; OS Euler2.8 | NV SXM2 V100-32G |
| uploaded Date | 12/30/2020 (month/day/year) | 06/11/2021 (month/day/year) |
| MindSpore Version | 1.1.0-alpha | 1.2.0 |
| Dataset | PASCAL VOC 2012 and SBD | PASCAL VOC 2012 and SBD |
| Training Parameters | epoch=500, steps=330, batch_size = 32, lr=0.015 | epoch=500, steps=330, batch_size = 8, lr=0.005 |
| Optimizer | Momentum | Momentum |
| Loss Function | Softmax Cross Entropy | Softmax Cross Entropy |
| outputs | probability | probability |
| Loss | 0.038 | 0.036 |
| Speed | 1pc: 564.652 ms/step; | 1pc: 455.460 ms/step; |
| Scripts | [FCN script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/FCN8s)
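
The two Speed entries are per-step times measured with different batch sizes, so they are easier to compare as images per second. The sketch below assumes that the listed batch_size is the per-device batch and that the "1pc" figure was measured with it. The table states neither, and `images_per_second` is a hypothetical helper.

```python
def images_per_second(batch_size, ms_per_step):
    """Convert a per-step time in milliseconds to throughput in images per second."""
    return batch_size / (ms_per_step / 1000.0)


# Values taken from the training performance table above
ascend = images_per_second(32, 564.652)
gpu = images_per_second(8, 455.460)
print(f"Ascend: {ascend:.1f} img/s, GPU: {gpu:.1f} img/s")
```

Under these assumptions the shorter GPU step time does not mean higher throughput, because each GPU step processes a quarter as many images.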
### Inference Performance
#### FCN8s on PASCAL VOC
| Parameters | Ascend
| ------------------- | ---------------------------
| Model Version | FCN-8s
| Resource | Ascend 910; OS Euler2.8
| Uploaded Date | 10/29/2020 (month/day/year)
| MindSpore Version | 1.1.0-alpha
| Dataset | PASCAL VOC 2012
| batch_size | 16
| outputs | probability
| mean IoU | 64.67
| Parameters | Ascend | GPU
| ------------------- | --------------------------- | ---------------------------
| Model Version | FCN-8s | FCN-8s
| Resource | Ascend 910; OS Euler2.8 | NV SXM2 V100-32G
| Uploaded Date | 10/29/2020 (month/day/year) | 06/11/2021 (month/day/year)
| MindSpore Version | 1.1.0-alpha | 1.2.0
| Dataset | PASCAL VOC 2012 | PASCAL VOC 2012
| batch_size | 16 | 16
| outputs | probability | probability
| mean IoU | 64.67 | 64.72
## [How to Use](#contents)


@ -0,0 +1,85 @@
# Builtin Configurations (DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
# url for modelarts
data_url: ""
train_url: ""
checkpoint_url: ""
# path for local
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
device_target: "GPU"
enable_profiling: False
checkpoint_path: "./checkpoint/"
checkpoint_file: "./checkpoint/.ckpt"
# ======================================================================================
# common options
crop_size: 512
image_mean: [103.53, 116.28, 123.675]
image_std: [57.375, 57.120, 58.395]
ignore_label: 255
num_classes: 21
model: "FCN8s"
# ======================================================================================
# Training options
train_batch_size: 8
min_scale: 0.5
max_scale: 2.0
data_file: "./vocaug_local_mindrecords/vocaug_local_mindrecords.mindrecords" # change to your own path of train data
# optimizer
train_epochs: 500
base_lr: 0.005
loss_scale: 1024
# model
ckpt_vgg16: "./vgg16_predtrain.ckpt" # change to your own path of backbone pretrain
ckpt_pre_trained: ""
save_steps: 330
keep_checkpoint_max: 5
ckpt_dir: "./ckpt"
# ======================================================================================
# Eval options
eval_batch_size: 16
data_root: "./VOCdevkit/VOC2012" # change to your own path of val data
data_lst: "./VOCdevkit/VOC2012/ImageSets/Segmentation/val.txt" # change to your own path of val data list
scales: [1.0]
flip: False
freeze_bn: False
ckpt_file: "./FCN8s_1-500_220.ckpt" # change to your own path of evaluate model
---
# Help description for each configuration
enable_modelarts: "Whether to train on modelarts, default: False"
data_url: "Url for modelarts"
train_url: "Url for modelarts"
data_path: "The location of input data"
output_path: "The location of the output file"
device_target: "Target device, GPU or Ascend (default in this file: GPU)"
enable_profiling: "Whether to enable profiling while training, default: False"
crop_size: "crop_size"
image_mean: "image_mean"
image_std: "image std"
ignore_label: "ignore label"
num_classes: "number of classes"
model: "select model"
data_file: "path of train data"
train_batch_size: "train_batch_size"
min_scale: "min scales of train"
max_scale: "max scales of train"
train_epochs: "train epoch"
base_lr: "base lr"
loss_scale: "loss scales"
ckpt_vgg16: "backbone pretrain"
ckpt_pre_trained: "model pretrain"
data_root: "root path of val data"
eval_batch_size: "eval batch size"
data_lst: "list of val data"
scales: "scales of evaluation"
flip: "whether to flip images during evaluation"
freeze_bn: "freeze bn"
ckpt_file: "model to evaluate"
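
The README commands pass `--config_path` and `--device_target` on top of these defaults. The sketch below shows one way command-line flags could override values loaded from the YAML file. The loader in `src/model_utils/config.py` is not shown in this change, so the details here are assumptions, `merge_config` is a hypothetical name, and a plain dict stands in for the parsed YAML.

```python
import argparse


def merge_config(defaults, argv):
    """Overlay command-line flags on a dict of default settings (hypothetical helper)."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--config_path", default="default_config.yaml")
    for key, value in defaults.items():
        # Only scalar int/float/str options become flags in this sketch;
        # booleans and lists are left at their default values.
        if isinstance(value, bool) or not isinstance(value, (int, float, str)):
            continue
        parser.add_argument(f"--{key}", type=type(value), default=value)
    args = parser.parse_args(argv)
    merged = dict(defaults)
    merged.update(vars(args))
    return merged


# A few defaults copied from the YAML above; the dict stands in for the parsed file
defaults = {"device_target": "GPU", "train_batch_size": 8, "base_lr": 0.005, "flip": False}
cfg = merge_config(defaults, ["--config_path=gpu_default_config.yaml", "--train_batch_size", "4"])
print(cfg["device_target"], cfg["train_batch_size"], cfg["base_lr"])
```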


@ -0,0 +1,55 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 2 ]; then
echo "Usage: sh run_distribute_train_gpu.sh [RANK_SIZE] [TRAIN_DATA_DIR]"
exit 1
fi
get_real_path() {
if [ "${1:0:1}" == "/" ]; then
echo "$1"
else
echo "$(realpath -m $PWD/$1)"
fi
}
export RANK_SIZE=$1
PROJECT_DIR=$(cd ./"`dirname $0`" || exit; pwd)
TRAIN_DATA_DIR=$(get_real_path "$2")
if [ ! -d "$TRAIN_DATA_DIR" ]; then
echo "error: TRAIN_DATA_DIR=$TRAIN_DATA_DIR is not a directory"
exit 1
fi
if [ -d "distribute_train" ]; then
rm -rf ./distribute_train
fi
mkdir ./distribute_train
cp ./*.py ./distribute_train
cp ./*.yaml ./distribute_train
cp -r ./src ./distribute_train
cd ./distribute_train || exit
CONFIG_FILE="$PROJECT_DIR/../gpu_default_config.yaml"
mpirun --allow-run-as-root -n $RANK_SIZE --output-filename log_output --merge-stderr-to-stdout \
nohup python train.py \
--config_path=$CONFIG_FILE \
--device_target=GPU > log.txt 2>&1 &
cd ..


@ -0,0 +1,44 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
if [ $# != 1 ]
then
echo "Usage: sh scripts/run_standalone_train_gpu.sh DEVICE_ID"
exit 1
fi
export DEVICE_ID=$1
PROJECT_DIR=$(cd ./"`dirname $0`" || exit; pwd)
train_path=train_standalone${DEVICE_ID}
if [ -d ${train_path} ]; then
rm -rf ${train_path}
fi
mkdir -p ${train_path}
cp -r ./src ${train_path}
cp ./train.py ${train_path}
cp ./*.yaml ${train_path}
echo "start training for device $DEVICE_ID"
cd ${train_path}|| exit
CONFIG_FILE="$PROJECT_DIR/../gpu_default_config.yaml"
nohup python train.py \
--config_path=$CONFIG_FILE \
--device_target=GPU > log 2>&1 &
cd ..


@ -21,7 +21,7 @@ from mindspore.context import ParallelMode
import mindspore.nn as nn
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.communication.management import init
from mindspore.communication.management import init, get_rank, get_group_size
from mindspore.train.callback import LossMonitor, TimeMonitor
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.common import set_seed
@ -31,7 +31,7 @@ from src.utils.lr_scheduler import CosineAnnealingLR
from src.nets.FCN8s import FCN8s
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper
from src.model_utils.device_adapter import get_device_id, get_device_num, get_rank_id
from src.model_utils.device_adapter import get_device_id, get_device_num
set_seed(1)
@ -45,7 +45,7 @@ def modelarts_pre_process():
def train():
device_num = get_device_num()
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, save_graphs=False,
device_target='Ascend', device_id=get_device_id())
device_target=config.device_target, device_id=get_device_id())
# init multicards training
config.rank = 0
config.group_size = 1
@ -53,8 +53,8 @@ def train():
parallel_mode = ParallelMode.DATA_PARALLEL
context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=device_num)
init()
config.rank = get_rank_id()
config.group_size = get_device_num()
config.rank = get_rank()
config.group_size = get_group_size()
# dataset
dataset = data_generator.SegDataset(image_mean=config.image_mean,
@ -112,10 +112,15 @@ def train():
# loss scale
manager_loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
optimizer = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=0.9, weight_decay=0.0001,
loss_scale=config.loss_scale)
model = Model(net, loss_fn=loss_, loss_scale_manager=manager_loss_scale, optimizer=optimizer, amp_level="O3")
if config.device_target == "Ascend":
optimizer = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=0.9, weight_decay=0.0001,
loss_scale=config.loss_scale)
model = Model(net, loss_fn=loss_, loss_scale_manager=manager_loss_scale, optimizer=optimizer, amp_level="O3")
elif config.device_target == "GPU":
optimizer = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=0.9, weight_decay=0.0001)
model = Model(net, loss_fn=loss_, optimizer=optimizer)
else:
raise ValueError("Unsupported platform.")
# callback for saving ckpts
time_cb = TimeMonitor(data_size=iters_per_epoch)
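
The device branch added to `train()` can be summarized without MindSpore as a small dispatch function. This is a sketch only: `build_train_kwargs` is a hypothetical name, and the returned dicts merely mirror the keyword arguments that differ between the two branches in this diff.

```python
def build_train_kwargs(device_target, loss_scale):
    """Return the per-device extra keyword arguments for the optimizer and the Model."""
    if device_target == "Ascend":
        # Ascend branch: fixed loss scaling plus amp_level "O3", as in the diff
        return {
            "optimizer": {"loss_scale": loss_scale},
            "model": {"amp_level": "O3", "use_loss_scale_manager": True},
        }
    if device_target == "GPU":
        # GPU branch: no loss scaling and no amp_level argument
        return {"optimizer": {}, "model": {}}
    raise ValueError(f"Unsupported platform: {device_target}")


print(build_train_kwargs("GPU", 1024))
```

Naming the rejected value in the error message, as done here, makes a mistyped `--device_target` easier to spot than the bare "Unsupported platform." string.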