forked from mindspore-Ecosystem/mindspore
commit
b71a20ed58
|
@ -66,6 +66,27 @@ Dataset used:
|
|||
python eval.py --device_id device_id
|
||||
```
|
||||
|
||||
- running on GPU with gpu default parameters
|
||||
|
||||
```python
|
||||
# GPU单卡训练示例
|
||||
python train.py \
|
||||
--config_path=gpu_default_config.yaml \
|
||||
--device_target=GPU
|
||||
|
||||
# GPU多卡训练示例
|
||||
export RANK_SIZE=8
|
||||
mpirun --allow-run-as-root -n $RANK_SIZE --output-filename log_output --merge-stderr-to-stdout \
|
||||
python train.py \
|
||||
--config_path=gpu_default_config.yaml \
|
||||
--device_target=GPU
|
||||
|
||||
# GPU评估示例
|
||||
python eval.py \
|
||||
--config_path=gpu_default_config.yaml \
|
||||
--device_target=GPU
|
||||
```
|
||||
|
||||
# [脚本介绍](#contents)
|
||||
|
||||
## [脚本以及简单代码](#contents)
|
||||
|
@ -79,6 +100,8 @@ Dataset used:
|
|||
├── scripts
|
||||
├── run_train.sh
|
||||
├── run_standalone_train.sh
|
||||
├── run_standalone_train_gpu.sh // train in gpu with single device
|
||||
├── run_distribute_train_gpu.sh // train in gpu with multi device
|
||||
├── run_eval.sh
|
||||
├── run_infer_310.sh // Ascend推理shell脚本
|
||||
├── build_data.sh
|
||||
|
@ -97,7 +120,8 @@ Dataset used:
|
|||
│ ├──device_adapter.py // getting device info
|
||||
│ ├──local_adapter.py // getting device info
|
||||
│ ├──moxing_adapter.py // Decorator
|
||||
├── default_config.yaml // Parameters config
|
||||
├── default_config.yaml // Ascend parameters config
|
||||
├── gpu_default_config.yaml // GPU parameters config
|
||||
├── train.py // training script
|
||||
├── postprogress.py // 310推理后处理脚本
|
||||
├── export.py // 将checkpoint文件导出到air/mindir
|
||||
|
@ -138,7 +162,7 @@ Dataset used:
|
|||
'ckpt_dir': './ckpt',
|
||||
```
|
||||
|
||||
如需获取更多信息,请查看`default_config.yaml`.
|
||||
如需获取更多信息,Ascend请查看`default_config.yaml`, GPU请查看`gpu_default_config.yaml`.
|
||||
|
||||
## [生成数据步骤](#contents)
|
||||
|
||||
|
@ -175,6 +199,31 @@ Dataset used:
|
|||
sh scripts/run_train.sh [DEVICE_NUM] rank_table.json
|
||||
```
|
||||
|
||||
- running on GPU with gpu default parameters
|
||||
|
||||
```python
|
||||
# GPU单卡训练示例
|
||||
python train.py \
|
||||
--config_path=gpu_default_config.yaml \
|
||||
--device_target=GPU
|
||||
or
|
||||
sh scripts/run_standalone_train_gpu.sh DEVICE_ID
|
||||
|
||||
# GPU八卡训练示例
|
||||
export RANK_SIZE=8
|
||||
mpirun --allow-run-as-root -n $RANK_SIZE --output-filename log_output --merge-stderr-to-stdout \
|
||||
python train.py \
|
||||
--config_path=gpu_default_config.yaml \
|
||||
--device_target=GPU
|
||||
or
|
||||
sh run_distribute_train_gpu.sh [RANK_SIZE] [TRAIN_DATA_DIR]
|
||||
|
||||
# GPU评估示例
|
||||
python eval.py \
|
||||
--config_path=gpu_default_config.yaml \
|
||||
--device_target=GPU
|
||||
```
|
||||
|
||||
训练时,训练过程中的epch和step以及此时的loss和精确度会呈现log.txt中:
|
||||
|
||||
```python
|
||||
|
@ -261,10 +310,12 @@ Dataset used:
|
|||
|
||||
### 评估
|
||||
|
||||
- 在Ascend上使用PASCAL VOC 2012 验证集进行评估
|
||||
- 在Ascend或GPU上使用PASCAL VOC 2012 验证集进行评估
|
||||
|
||||
在使用命令运行前,请检查用于评估的checkpoint的路径。请设置路径为到checkpoint的绝对路径,如 "/data/workspace/mindspore_dataset/FCN/FCN/model_new/FCN8s-500_82.ckpt"。
|
||||
|
||||
- eval on Ascend
|
||||
|
||||
```python
|
||||
python eval.py
|
||||
```
|
||||
|
@ -273,7 +324,7 @@ Dataset used:
|
|||
sh scripts/run_eval.sh DATA_ROOT DATA_LST CKPT_PATH
|
||||
```
|
||||
|
||||
以上的python命令会在终端上运行,你可以在终端上查看此次评估的结果。测试集的精确度会以如下方式呈现:
|
||||
以上的python命令会在终端上运行,你可以在终端上查看此次评估的结果。测试集的精确度会以类似如下方式呈现:
|
||||
|
||||
```python
|
||||
mean IoU 0.6467
|
||||
|
@ -306,6 +357,20 @@ python export.py
|
|||
mean IoU 0.0.64519877
|
||||
```
|
||||
|
||||
- eval on GPU
|
||||
|
||||
```python
|
||||
python eval.py \
|
||||
--config_path=gpu_default_config.yaml \
|
||||
--device_target=GPU
|
||||
```
|
||||
|
||||
以上的python命令会在终端上运行,你可以在终端上查看此次评估的结果。测试集的精确度会以类似如下方式呈现:
|
||||
|
||||
```python
|
||||
mean IoU 0.6472
|
||||
```
|
||||
|
||||
# [模型介绍](#contents)
|
||||
|
||||
## [性能](#contents)
|
||||
|
@ -314,35 +379,35 @@ python export.py
|
|||
|
||||
#### FCN8s on PASCAL VOC 2012
|
||||
|
||||
| Parameters | Ascend
|
||||
| -------------------------- | -----------------------------------------------------------
|
||||
| Model Version | FCN-8s
|
||||
| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory 755G; OS Euler2.8
|
||||
| uploaded Date | 12/30/2020 (month/day/year)
|
||||
| MindSpore Version | 1.1.0-alpha
|
||||
| Dataset | PASCAL VOC 2012 and SBD
|
||||
| Training Parameters | epoch=500, steps=330, batch_size = 32, lr=0.015
|
||||
| Optimizer | Momentum
|
||||
| Loss Function | Softmax Cross Entropy
|
||||
| outputs | probability
|
||||
| Loss | 0.038
|
||||
| Speed | 1pc: 564.652 ms/step;
|
||||
| Parameters | Ascend | GPU |
|
||||
| -------------------------- | ------------------------------------------------------------| -------------------------------------------------|
|
||||
| Model Version | FCN-8s | FCN-8s |
|
||||
| Resource | Ascend 910; CPU 2.60GHz, 192cores; Memory 755G; OS Euler2.8 | NV SMX2 V100-32G |
|
||||
| uploaded Date | 12/30/2020 (month/day/year) | 06/11/2021 (month/day/year) |
|
||||
| MindSpore Version | 1.1.0-alpha | 1.2.0 |
|
||||
| Dataset | PASCAL VOC 2012 and SBD | PASCAL VOC 2012 and SBD |
|
||||
| Training Parameters | epoch=500, steps=330, batch_size = 32, lr=0.015 | epoch=500, steps=330, batch_size = 8, lr=0.005 |
|
||||
| Optimizer | Momentum | Momentum |
|
||||
| Loss Function | Softmax Cross Entropy | Softmax Cross Entropy |
|
||||
| outputs | probability | probability |
|
||||
| Loss | 0.038 | 0.036 |
|
||||
| Speed | 1pc: 564.652 ms/step; | 1pc: 455.460 ms/step; |
|
||||
| Scripts | [FCN script](https://gitee.com/mindspore/mindspore/tree/master/model_zoo/official/cv/FCN8s)
|
||||
|
||||
### Inference Performance
|
||||
|
||||
#### FCN8s on PASCAL VOC
|
||||
|
||||
| Parameters | Ascend
|
||||
| ------------------- | ---------------------------
|
||||
| Model Version | FCN-8s
|
||||
| Resource | Ascend 910; OS Euler2.8
|
||||
| Uploaded Date | 10/29/2020 (month/day/year)
|
||||
| MindSpore Version | 1.1.0-alpha
|
||||
| Dataset | PASCAL VOC 2012
|
||||
| batch_size | 16
|
||||
| outputs | probability
|
||||
| mean IoU | 64.67
|
||||
| Parameters | Ascend | GPU
|
||||
| ------------------- | --------------------------- | ---------------------------
|
||||
| Model Version | FCN-8s | FCN-8s
|
||||
| Resource | Ascend 910; OS Euler2.8 | NV SMX2 V100-32G
|
||||
| Uploaded Date | 10/29/2020 (month/day/year) | 06/11/2021 (month/day/year)
|
||||
| MindSpore Version | 1.1.0-alpha | 1.2.0
|
||||
| Dataset | PASCAL VOC 2012 | PASCAL VOC 2012
|
||||
| batch_size | 16 | 16
|
||||
| outputs | probability | probability
|
||||
| mean IoU | 64.67 | 64.72
|
||||
|
||||
## [如何使用](#contents)
|
||||
|
||||
|
|
|
@ -0,0 +1,85 @@
|
|||
# Builtin Configurations(DO NOT CHANGE THESE CONFIGURATIONS unlesee you know exactly what you are doing)
|
||||
enable_modelarts: False
|
||||
# url for modelarts
|
||||
data_url: ""
|
||||
train_url: ""
|
||||
checkpoint_url: ""
|
||||
# path for local
|
||||
data_path: "/cache/data"
|
||||
output_path: "/cache/train"
|
||||
load_path: "/cache/checkpoint_path"
|
||||
device_target: "GPU"
|
||||
enable_profiling: False
|
||||
checkpoint_path: "./checkpoint/"
|
||||
checkpoint_file: "./checkpoint/.ckpt"
|
||||
# ======================================================================================
|
||||
# common options
|
||||
|
||||
crop_size: 512
|
||||
image_mean: [103.53, 116.28, 123.675]
|
||||
image_std: [57.375, 57.120, 58.395]
|
||||
ignore_label: 255
|
||||
num_classes: 21
|
||||
model: "FCN8s"
|
||||
|
||||
# ======================================================================================
|
||||
# Training options
|
||||
train_batch_size: 8
|
||||
min_scale: 0.5
|
||||
max_scale: 2.0
|
||||
data_file: "./vocaug_local_mindrecords/vocaug_local_mindrecords.mindrecords" # change to your own path of train data
|
||||
|
||||
# optimizer
|
||||
train_epochs: 500
|
||||
base_lr: 0.005
|
||||
loss_scale: 1024
|
||||
|
||||
# model
|
||||
ckpt_vgg16: "./vgg16_predtrain.ckpt" # change to your own path of backbone pretrain
|
||||
ckpt_pre_trained: ""
|
||||
save_steps: 330
|
||||
keep_checkpoint_max: 5
|
||||
ckpt_dir: "./ckpt"
|
||||
|
||||
|
||||
# ======================================================================================
|
||||
# Eval options
|
||||
eval_batch_size: 16
|
||||
data_root: "./VOCdevkit/VOC2012" # change to your own path of val data
|
||||
data_lst: "./VOCdevkit/VOC2012/ImageSets/Segmentation/val.txt" # change to your own path of val data list
|
||||
scales: [1.0]
|
||||
flip: False
|
||||
freeze_bn: False
|
||||
ckpt_file: "./FCN8s_1-500_220.ckpt" # change to your own path of evaluate model
|
||||
|
||||
|
||||
---
|
||||
# Help description for each configuration
|
||||
enable_modelarts: "Whether training on modelarts default: False"
|
||||
data_url: "Url for modelarts"
|
||||
train_url: "Url for modelarts"
|
||||
data_path: "The location of input data"
|
||||
output_pah: "The location of the output file"
|
||||
device_target: "device id of GPU or Ascend. (Default: None)"
|
||||
enable_profiling: "Whether enable profiling while training default: False"
|
||||
crop_size: "crop_size"
|
||||
image_mean: "image_mean"
|
||||
image_std: "image std"
|
||||
ignore_label: "ignore label"
|
||||
num_classes: "number of classes"
|
||||
model: "select model"
|
||||
data_file: "path of train data"
|
||||
train_batch_size: "train_batch_size"
|
||||
min_scale: "min scales of train"
|
||||
max_scale: "max scales of train"
|
||||
train_epochs: "train epoch"
|
||||
base_lr: "base lr"
|
||||
loss_scale: "loss scales"
|
||||
ckpt_vgg16: "backbone pretrain"
|
||||
ckpt_pre_trained: "model pretrain"
|
||||
data_root: "root path of val data"
|
||||
eval_batch_size: "eval batch size"
|
||||
data_lst: "list of val data"
|
||||
scales: "scales of evaluation"
|
||||
flip: "freeze bn"
|
||||
ckpt_file: "model to evaluate"
|
|
@ -0,0 +1,55 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 2 ]; then
|
||||
echo "Usage: sh run_distribute_train_gpu.sh [RANK_SIZE] [TRAIN_DATA_DIR]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
get_real_path() {
|
||||
if [ "${1:0:1}" == "/" ]; then
|
||||
echo "$1"
|
||||
else
|
||||
echo "$(realpath -m $PWD/$1)"
|
||||
fi
|
||||
}
|
||||
|
||||
export RANK_SIZE=$1
|
||||
PROJECT_DIR=$(cd ./"`dirname $0`" || exit; pwd)
|
||||
TRAIN_DATA_DIR=$(get_real_path $2)
|
||||
|
||||
if [ ! -d $TRAIN_DATA_DIR ]; then
|
||||
echo "error: TRAIN_DATA_DIR=$TRAIN_DATA_DIR is not a directory"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ -d "distribute_train" ]; then
|
||||
rm -rf ./distribute_train
|
||||
fi
|
||||
|
||||
mkdir ./distribute_train
|
||||
cp ./*.py ./distribute_train
|
||||
cp ./*.yaml ./distribute_train
|
||||
cp -r ./src ./distribute_train
|
||||
cd ./distribute_train || exit
|
||||
|
||||
CONFIG_FILE="$PROJECT_DIR/../gpu_default_config.yaml"
|
||||
|
||||
mpirun --allow-run-as-root -n $RANK_SIZE --output-filename log_output --merge-stderr-to-stdout \
|
||||
nohup python train.py \
|
||||
--config_path=$CONFIG_FILE \
|
||||
--device_target=GPU > log.txt 2>&1 &
|
||||
cd ..
|
|
@ -0,0 +1,44 @@
|
|||
#!/bin/bash
|
||||
# Copyright 2021 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 1 ]
|
||||
then
|
||||
echo "Usage: sh scripts/run_standalone_train_gpu.sh DEVICE_ID"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
export DEVICE_ID=$1
|
||||
PROJECT_DIR=$(cd ./"`dirname $0`" || exit; pwd)
|
||||
train_path=train_standalone${DEVICE_ID}
|
||||
|
||||
if [ -d ${train_path} ]; then
|
||||
rm -rf ${train_path}
|
||||
fi
|
||||
mkdir -p ${train_path}
|
||||
cp -r ./src ${train_path}
|
||||
cp ./train.py ${train_path}
|
||||
cp ./*.yaml ${train_path}
|
||||
|
||||
echo "start training for device $DEVICE_ID"
|
||||
|
||||
cd ${train_path}|| exit
|
||||
|
||||
CONFIG_FILE="$PROJECT_DIR/../gpu_default_config.yaml"
|
||||
|
||||
nohup python train.py \
|
||||
--config_path=$CONFIG_FILE \
|
||||
--device_target=GPU > log 2>&1 &
|
||||
cd ..
|
|
@ -21,7 +21,7 @@ from mindspore.context import ParallelMode
|
|||
import mindspore.nn as nn
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.communication.management import init
|
||||
from mindspore.communication.management import init, get_rank, get_group_size
|
||||
from mindspore.train.callback import LossMonitor, TimeMonitor
|
||||
from mindspore.train.loss_scale_manager import FixedLossScaleManager
|
||||
from mindspore.common import set_seed
|
||||
|
@ -31,7 +31,7 @@ from src.utils.lr_scheduler import CosineAnnealingLR
|
|||
from src.nets.FCN8s import FCN8s
|
||||
from src.model_utils.config import config
|
||||
from src.model_utils.moxing_adapter import moxing_wrapper
|
||||
from src.model_utils.device_adapter import get_device_id, get_device_num, get_rank_id
|
||||
from src.model_utils.device_adapter import get_device_id, get_device_num
|
||||
|
||||
|
||||
set_seed(1)
|
||||
|
@ -45,7 +45,7 @@ def modelarts_pre_process():
|
|||
def train():
|
||||
device_num = get_device_num()
|
||||
context.set_context(mode=context.GRAPH_MODE, enable_auto_mixed_precision=True, save_graphs=False,
|
||||
device_target='Ascend', device_id=get_device_id())
|
||||
device_target=config.device_target, device_id=get_device_id())
|
||||
# init multicards training
|
||||
config.rank = 0
|
||||
config.group_size = 1
|
||||
|
@ -53,8 +53,8 @@ def train():
|
|||
parallel_mode = ParallelMode.DATA_PARALLEL
|
||||
context.set_auto_parallel_context(parallel_mode=parallel_mode, gradients_mean=True, device_num=device_num)
|
||||
init()
|
||||
config.rank = get_rank_id()
|
||||
config.group_size = get_device_num()
|
||||
config.rank = get_rank()
|
||||
config.group_size = get_group_size()
|
||||
|
||||
# dataset
|
||||
dataset = data_generator.SegDataset(image_mean=config.image_mean,
|
||||
|
@ -112,10 +112,15 @@ def train():
|
|||
|
||||
# loss scale
|
||||
manager_loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
|
||||
|
||||
optimizer = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=0.9, weight_decay=0.0001,
|
||||
loss_scale=config.loss_scale)
|
||||
model = Model(net, loss_fn=loss_, loss_scale_manager=manager_loss_scale, optimizer=optimizer, amp_level="O3")
|
||||
if config.device_target == "Ascend":
|
||||
optimizer = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=0.9, weight_decay=0.0001,
|
||||
loss_scale=config.loss_scale)
|
||||
model = Model(net, loss_fn=loss_, loss_scale_manager=manager_loss_scale, optimizer=optimizer, amp_level="O3")
|
||||
elif config.device_target == "GPU":
|
||||
optimizer = nn.Momentum(params=net.trainable_params(), learning_rate=lr, momentum=0.9, weight_decay=0.0001)
|
||||
model = Model(net, loss_fn=loss_, optimizer=optimizer)
|
||||
else:
|
||||
raise ValueError("Unsupported platform.")
|
||||
|
||||
# callback for saving ckpts
|
||||
time_cb = TimeMonitor(data_size=iters_per_epoch)
|
||||
|
|
Loading…
Reference in New Issue