This commit is contained in:
huchunmei 2021-05-28 10:34:44 +08:00
parent ae5adc2986
commit d016aa9b94
24 changed files with 808 additions and 233 deletions

View File

@@ -86,16 +86,22 @@ For FP16 operators, if the input data type is FP32, the backend of MindSpore will
    ├─run_eval_gpu.sh                # launch evaluation with GPU platform
    └─run_eval.sh                    # launch evaluation with Ascend platform
  ├─src
    ├─dataset.py                     # data preprocessing
    ├─inception_v3.py                # network definition
    ├─loss.py                        # customized CrossEntropy loss function
    ├─lr_generator.py                # learning rate generator
    └─model_utils
      ├─config.py                    # processing configuration parameters
      ├─device_adapter.py            # get cloud ID
      ├─local_adapter.py             # get local ID
      └─moxing_adapter.py            # parameter processing
  ├─default_config.yaml              # training parameter profile (Ascend)
  ├─default_config_cpu.yaml          # training parameter profile (CPU)
  ├─default_config_gpu.yaml          # training parameter profile (GPU)
  ├─eval.py                          # eval net
  ├─export.py                        # convert checkpoint
  ├─postprogress.py                  # post-processing for 310 inference
  └─train.py                         # train net
```
## [Script Parameters](#contents)

View File

@@ -97,16 +97,22 @@ The overall network architecture of InceptionV3 is as follows
    ├─run_eval_gpu.sh                # launch evaluation with GPU platform
    └─run_eval.sh                    # launch evaluation with Ascend platform
  ├─src
    ├─dataset.py                     # data preprocessing
    ├─inception_v3.py                # network definition
    ├─loss.py                        # customized cross-entropy loss function
    ├─lr_generator.py                # learning rate generator
    └─model_utils
      ├─config.py                    # parse .yaml configuration parameters
      ├─device_adapter.py            # get cloud ID
      ├─local_adapter.py             # get local ID
      └─moxing_adapter.py            # data preparation on the cloud
  ├─default_config.yaml              # training parameter profile (Ascend)
  ├─default_config_cpu.yaml          # training parameter profile (CPU)
  ├─default_config_gpu.yaml          # training parameter profile (GPU)
  ├─eval.py                          # evaluation network
  ├─export.py                        # script to export AIR/MINDIR models
  ├─postprogress.py                  # post-processing script for 310 inference
  └─train.py                         # training network
```
## Script Parameters

View File

@@ -0,0 +1,72 @@
# Builtin Configurations (DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
data_url: ""
train_url: ""
checkpoint_url: ""
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
device_target: Ascend
enable_profiling: False
# ==============================================================================
dataset_path: "/cache/data"
ckpt_path: '/cache/data/'
checkpoint: './inceptionv3/inceptionv3-rank3_1-247_1251.ckpt'
ckpt_file: '/cache/data/inceptionv3/inceptionv3-rank3_1-247_1251.ckpt'
resume: ''
is_distributed: False
device_id: 0
platform: 'Ascend'
file_name: 'inceptionv3'
file_format: 'AIR'
width: 299
height: 299
modelarts_dataset_unzip_name: 'ImageNet_Original'
need_modelarts_dataset_unzip: True
ckpt_save_dir: './ckpt/'
result_path: '' # "result file path"
label_file": '' # "label file"
# Training options
random_seed: 1
work_nums: 8
decay_method: 'cosine'
loss_scale: 1024
batch_size: 128
epoch_size: 250
num_classes: 1000
ds_type: 'imagenet'
ds_sink_mode: True
smooth_factor: 0.1
aux_factor: 0.2
lr_init: 0.00004
lr_max: 0.4
lr_end: 0.000004
warmup_epochs: 1
weight_decay: 0.00004
momentum: 0.9
opt_eps: 1.0
keep_checkpoint_max: 10
is_save_on_master: 0
dropout_keep_prob: 0.8
has_bias: False
amp_level: 'O3'
---
# Config description for each option
enable_modelarts: 'Whether training on modelarts, default: False'
data_url: 'Dataset url for obs'
train_url: 'Training output url for obs'
data_path: 'Dataset path for local'
output_path: 'Training output path for local'
device_target: 'Target device type'
enable_profiling: 'Whether enable profiling while training, default: False'
file_name: 'output file name.'
file_format: 'file format'
---
device_target: ['Ascend', 'GPU', 'CPU']
file_format: ['AIR', 'ONNX', 'MINDIR']
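
Each `default_config*.yaml` in this commit is really three YAML documents separated by `---`: the defaults, the per-option help text, and the allowed choices. A minimal sketch (assuming only PyYAML) of how `src/model_utils/config.py`, added later in this commit, splits them apart:

```python
# Sketch only: mirrors parse_yaml() in src/model_utils/config.py.
# The file above contains exactly three '---'-separated documents.
import yaml

with open("default_config.yaml", "r") as fin:
    cfg, cfg_helper, cfg_choices = list(yaml.load_all(fin.read(), Loader=yaml.FullLoader))

print(cfg["lr_max"])               # 0.4
print(cfg_helper["file_format"])   # 'file format'
print(cfg_choices["file_format"])  # ['AIR', 'ONNX', 'MINDIR']
```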

View File

@@ -0,0 +1,73 @@
# Builtin Configurations (DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
data_url: ""
train_url: ""
checkpoint_url: ""
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
device_target: Ascend
enable_profiling: False
# ==============================================================================
dataset_path: "/cache/data"
ckpt_path: '/cache/data/'
checkpoint: './inceptionv3/inceptionv3-rank3_1-247_1251.ckpt'
ckpt_file: '/cache/data/inceptionv3/inceptionv3-rank3_1-247_1251.ckpt'
resume: ''
is_distributed: False
device_id: 0
platform: 'CPU'
file_name: 'inceptionv3'
file_format: 'AIR'
width: 299
height: 299
modelarts_dataset_unzip_name: 'ImageNet_Original'
need_modelarts_dataset_unzip: True
ckpt_save_dir: './ckpt/'
result_path: '' # "result file path"
label_file": '' # "label file"
# Training options
random_seed: 1
work_nums: 8
decay_method: 'cosine'
loss_scale: 1024
batch_size: 128
epoch_size: 120
num_classes: 10
ds_type: 'cifar10'
ds_sink_mode: False
smooth_factor: 0.1
aux_factor: 0.2
lr_init: 0.00004
lr_max: 0.1
lr_end: 0.000004
warmup_epochs: 1
weight_decay: 0.00004
momentum: 0.9
opt_eps: 1.0
keep_checkpoint_max: 10
is_save_on_master: 0
dropout_keep_prob: 0.8
has_bias: False
amp_level: 'O0'
---
# Config description for each option
enable_modelarts: 'Whether training on modelarts, default: False'
data_url: 'Dataset url for obs'
train_url: 'Training output url for obs'
data_path: 'Dataset path for local'
output_path: 'Training output path for local'
device_target: 'Target device type'
enable_profiling: 'Whether enable profiling while training, default: False'
file_name: 'inceptionv3 output air name.'
file_format: 'file format'
---
device_target: ['Ascend', 'GPU', 'CPU']
file_format: ['AIR', 'ONNX', 'MINDIR']

View File

@@ -0,0 +1,73 @@
# Builtin Configurations (DO NOT CHANGE THESE CONFIGURATIONS unless you know exactly what you are doing)
enable_modelarts: False
data_url: ""
train_url: ""
checkpoint_url: ""
data_path: "/cache/data"
output_path: "/cache/train"
load_path: "/cache/checkpoint_path"
device_target: Ascend
enable_profiling: False
# ==============================================================================
dataset_path: "/cache/data"
ckpt_path: '/cache/data/'
checkpoint: '/cache/data/inceptionv3/inceptionv3-rank3_1-247_1251.ckpt'
ckpt_file: '/cache/data/inceptionv3/inceptionv3-rank3_1-247_1251.ckpt'
resume: ''
is_distributed: False
device_id: 0
platform: 'GPU'
file_name: 'inceptionv3'
file_format: 'AIR'
width: 299
height: 299
modelarts_dataset_unzip_name: 'ImageNet_Original'
need_modelarts_dataset_unzip: True
ckpt_save_dir: './ckpt/'
result_path: '' # "result file path"
label_file": '' # "label file"
# Training options
random_seed: 1
work_nums: 8
decay_method: 'cosine'
loss_scale: 1
batch_size: 128
epoch_size: 250
num_classes: 1000
ds_type: 'imagenet'
ds_sink_mode: True
smooth_factor: 0.1
aux_factor: 0.2
lr_init: 0.00004
lr_max: 0.4
lr_end: 0.000004
warmup_epochs: 1
weight_decay: 0.00004
momentum: 0.9
opt_eps: 1.0
keep_checkpoint_max: 10
is_save_on_master: 0
dropout_keep_prob: 0.5
has_bias: True
amp_level: 'O0'
---
# Config description for each option
enable_modelarts: 'Whether training on modelarts, default: False'
data_url: 'Dataset url for obs'
train_url: 'Training output url for obs'
data_path: 'Dataset path for local'
output_path: 'Training output path for local'
device_target: 'Target device type'
enable_profiling: 'Whether enable profiling while training, default: False'
file_name: 'output file name.'
file_format: 'file format'
---
device_target: ['Ascend', 'GPU', 'CPU']
file_format: ['AIR', 'ONNX', 'MINDIR']

View File

@@ -13,56 +13,110 @@
# limitations under the License.
# ============================================================================
"""evaluate_imagenet"""
import argparse
import os
import time
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper
from src.model_utils.device_adapter import get_device_id, get_device_num
from src.dataset import create_dataset_imagenet, create_dataset_cifar10
from src.inception_v3 import InceptionV3
from src.loss import CrossEntropy_Val
import mindspore.nn as nn
from mindspore import context
from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.config import config_gpu, config_ascend, config_cpu
from src.dataset import create_dataset_imagenet, create_dataset_cifar10
from src.inception_v3 import InceptionV3
from src.loss import CrossEntropy_Val
CFG_DICT = {
"Ascend": config_ascend,
"GPU": config_gpu,
"CPU": config_cpu,
}
DS_DICT = {
"imagenet": create_dataset_imagenet,
"cifar10": create_dataset_cifar10,
}
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='image classification evaluation')
parser.add_argument('--checkpoint', type=str, default='', help='checkpoint of inception-v3 (Default: None)')
parser.add_argument('--dataset_path', type=str, default='', help='Dataset path')
parser.add_argument('--platform', type=str, default='GPU', choices=('Ascend', 'GPU', 'CPU'), help='run platform')
args_opt = parser.parse_args()
if args_opt.platform == 'Ascend':
def modelarts_process():
""" modelarts process """
def unzip(zip_file, save_dir):
import zipfile
s_time = time.time()
if not os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)):
zip_isexist = zipfile.is_zipfile(zip_file)
if zip_isexist:
fz = zipfile.ZipFile(zip_file, 'r')
data_num = len(fz.namelist())
print("Extract Start...")
print("unzip file num: {}".format(data_num))
data_print = int(data_num / 100) if data_num > 100 else 1
i = 0
for file in fz.namelist():
if i % data_print == 0:
print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
i += 1
fz.extract(file, save_dir)
print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),\
int(int(time.time() - s_time) % 60)))
print("Extract Done")
else:
print("This is not zip.")
else:
print("Zip has been extracted.")
if config.need_modelarts_dataset_unzip:
zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip")
save_dir_1 = os.path.join(config.data_path)
sync_lock = "/tmp/unzip_sync.lock"
# Each server contains 8 devices as most
if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
print("Zip file path: ", zip_file_1)
print("Unzip file save dir: ", save_dir_1)
unzip(zip_file_1, save_dir_1)
print("===Finish extract data synchronization===")
try:
os.mknod(sync_lock)
except IOError:
pass
while True:
if os.path.exists(sync_lock):
break
time.sleep(1)
print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
print("#" * 200, os.listdir(save_dir_1))
print("#" * 200, os.listdir(os.path.join(config.data_path, config.modelarts_dataset_unzip_name)))
config.dataset_path = os.path.join(config.data_path, config.modelarts_dataset_unzip_name)
config.checkpoint = os.path.join(config.dataset_path, config.checkpoint)
@moxing_wrapper(pre_process=modelarts_process)
def eval_inceptionv3():
if config.platform == 'Ascend':
device_id = int(os.getenv('DEVICE_ID'))
context.set_context(device_id=device_id)
cfg = CFG_DICT[args_opt.platform]
create_dataset = DS_DICT[cfg.ds_type]
context.set_context(mode=context.GRAPH_MODE, device_target=args_opt.platform)
net = InceptionV3(num_classes=cfg.num_classes, is_training=False)
ckpt = load_checkpoint(args_opt.checkpoint)
create_dataset = DS_DICT[config.ds_type]
context.set_context(mode=context.GRAPH_MODE, device_target=config.platform)
net = InceptionV3(num_classes=config.num_classes, is_training=False)
ckpt = load_checkpoint(config.checkpoint)
load_param_into_net(net, ckpt)
net.set_train(False)
cfg.rank = 0
cfg.group_size = 1
dataset = create_dataset(args_opt.dataset_path, False, cfg)
loss = CrossEntropy_Val(smooth_factor=0.1, num_classes=cfg.num_classes)
config.rank = 0
config.group_size = 1
dataset = create_dataset(config.dataset_path, False, config)
loss = CrossEntropy_Val(smooth_factor=0.1, num_classes=config.num_classes)
eval_metrics = {'Loss': nn.Loss(),
'Top1-Acc': nn.Top1CategoricalAccuracy(),
'Top5-Acc': nn.Top5CategoricalAccuracy()}
model = Model(net, loss, optimizer=None, metrics=eval_metrics)
metrics = model.eval(dataset, dataset_sink_mode=cfg.ds_sink_mode)
metrics = model.eval(dataset, dataset_sink_mode=config.ds_sink_mode)
print("metric: ", metrics)
if __name__ == '__main__':
config.dataset_path = os.path.join(config.dataset_path, 'validation_preprocess')
eval_inceptionv3()
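
The unzip step above relies on a file-lock barrier: only the first device on each server extracts the archive, while the rest poll for a flag file. A distilled sketch of the pattern (the helper name `run_once_per_server` is illustrative, not part of this commit):

```python
# Illustrative distillation of the barrier used in modelarts_process():
# one device per server does the work, the others wait for the flag file.
import os
import time
from src.model_utils.device_adapter import get_device_id, get_device_num

def run_once_per_server(work, sync_lock="/tmp/unzip_sync.lock"):
    if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
        work()                   # e.g. unzip the dataset archive
        try:
            os.mknod(sync_lock)  # publish the "done" flag
        except IOError:
            pass
    while not os.path.exists(sync_lock):
        time.sleep(1)            # remaining devices block until the flag exists
```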

View File

@@ -13,36 +13,27 @@
# limitations under the License.
# ============================================================================
"""export checkpoint file into air, onnx, mindir models"""
import argparse
import numpy as np
from src.model_utils.config import config
from src.model_utils.device_adapter import get_device_id
from src.inception_v3 import InceptionV3
import mindspore as ms
from mindspore import Tensor, load_checkpoint, load_param_into_net, export, context
from src.config import config_gpu as cfg
from src.inception_v3 import InceptionV3
parser = argparse.ArgumentParser(description='inceptionv3 export')
parser.add_argument("--device_id", type=int, default=0, help="Device id")
parser.add_argument("--batch_size", type=int, default=1, help="batch size")
parser.add_argument('--ckpt_file', type=str, required=True, help='inceptionv3 ckpt file.')
parser.add_argument('--file_name', type=str, default='inceptionv3', help='inceptionv3 output air name.')
parser.add_argument('--file_format', type=str, choices=["AIR", "MINDIR"], default='AIR', help='file format')
parser.add_argument('--width', type=int, default=299, help='input width')
parser.add_argument('--height', type=int, default=299, help='input height')
parser.add_argument("--device_target", type=str, choices=["Ascend", "GPU", "CPU"], default="Ascend",
help="device target")
args = parser.parse_args()
config.batch_size = 1
context.set_context(mode=context.GRAPH_MODE, device_target=args.device_target)
if args.device_target == "Ascend":
context.set_context(device_id=args.device_id)
context.set_context(mode=context.GRAPH_MODE, device_target=config.device_target)
if config.device_target == "Ascend":
context.set_context(device_id=get_device_id())
if __name__ == '__main__':
net = InceptionV3(num_classes=cfg.num_classes, is_training=False)
param_dict = load_checkpoint(args.ckpt_file)
net = InceptionV3(num_classes=config.num_classes, is_training=False)
param_dict = load_checkpoint(config.ckpt_file)
load_param_into_net(net, param_dict)
input_arr = Tensor(np.random.uniform(0.0, 1.0, size=[args.batch_size, 3, args.width, args.height]), ms.float32)
export(net, input_arr, file_name=args.file_name, file_format=args.file_format)
input_arr = Tensor(np.random.uniform(0.0, 1.0, size=[config.batch_size, 3, config.width, \
config.height]), ms.float32)
export(net, input_arr, file_name=config.file_name, file_format=config.file_format)

View File

@@ -14,13 +14,9 @@
# ============================================================================
'''post process for 310 inference'''
import os
import numpy as np
from src.model_utils.config import config

def read_label(label_file):
    f = open(label_file, "r")
@@ -55,4 +51,4 @@ def cal_acc(result_path, label_file):
    print("========accuracy:{}========".format(accuracy))

if __name__ == "__main__":
    cal_acc(config.result_path, config.label_file)

View File

@@ -15,10 +15,14 @@
# ============================================================================
DATA_DIR=$2
CKPT_PATH=$3
export RANK_TABLE_FILE=$1
export RANK_SIZE=8
export HCCL_CONNECT_TIMEOUT=600
BASE_PATH=$(cd ./"`dirname $0`" || exit; pwd)
CONFIG_FILE="${BASE_PATH}/../default_config.yaml"
cores=`cat /proc/cpuinfo|grep "processor" |wc -l`
echo "the number of logical core" $cores
avg_core_per_rank=`expr $cores \/ $RANK_SIZE`
@@ -37,14 +41,16 @@ do
    rm -rf train_parallel$i
    mkdir ./train_parallel$i
    cp ../*.py ./train_parallel$i
    cp ../*.yaml ./train_parallel$i
    cp -r ../src ./train_parallel$i
    cd ./train_parallel$i || exit
    echo "start training for rank $i, device $DEVICE_ID"
    env > env.log
    taskset -c $cmdopt python ./train.py --config_path=$CONFIG_FILE \
        --is_distributed=True \
        --platform=Ascend \
        --dataset_path=$DATA_DIR --ckpt_path=$CKPT_PATH > log.txt 2>&1 &
    cd ../
done

View File

@@ -14,5 +14,11 @@
# limitations under the License.
# ============================================================================
DATA_DIR=$1
CKPT_PATH=$2
BASE_PATH=$(cd ./"`dirname $0`" || exit; pwd)
CONFIG_FILE="${BASE_PATH}/../default_config_gpu.yaml"
mpirun --allow-run-as-root -n 8 --output-filename log_output --merge-stderr-to-stdout \
python ./train.py --config_path=$CONFIG_FILE --is_distributed --platform 'GPU' \
--dataset_path $DATA_DIR --ckpt_path=$CKPT_PATH > train.log 2>&1 &

View File

@@ -18,7 +18,10 @@ export DEVICE_ID=$1
DATA_DIR=$2
PATH_CHECKPOINT=$3
BASE_PATH=$(cd ./"`dirname $0`" || exit; pwd)
CONFIG_FILE="${BASE_PATH}/../default_config.yaml"
python ../eval.py --config_path=$CONFIG_FILE \
--platform=Ascend \
--checkpoint=$PATH_CHECKPOINT \
--dataset_path=$DATA_DIR > eval.log 2>&1 &

View File

@@ -15,4 +15,8 @@
# ============================================================================
DATA_DIR=$1
PATH_CHECKPOINT=$2
BASE_PATH=$(cd ./"`dirname $0`" || exit; pwd)
CONFIG_FILE="${BASE_PATH}/../default_config_cpu.yaml"
python ../eval.py --config_path=$CONFIG_FILE --platform 'CPU' --dataset_path $DATA_DIR --checkpoint $PATH_CHECKPOINT > eval.log 2>&1 &

View File

@@ -16,4 +16,8 @@
DEVICE_ID=$1
DATA_DIR=$2
PATH_CHECKPOINT=$3
BASE_PATH=$(cd ./"`dirname $0`" || exit; pwd)
CONFIG_FILE="${BASE_PATH}/../default_config_gpu.yaml"
CUDA_VISIBLE_DEVICES=$DEVICE_ID python ../eval.py --config_path=$CONFIG_FILE --platform 'GPU' --dataset_path $DATA_DIR --checkpoint $PATH_CHECKPOINT > eval.log 2>&1 &

View File

@@ -41,6 +41,9 @@ elif [ $# == 3 ]; then
fi
fi
BASE_PATH=$(cd ./"`dirname $0`" || exit; pwd)
CONFIG_FILE="${BASE_PATH}/../default_config.yaml"
echo $model
echo $data_path
echo $label_file
@@ -95,7 +98,7 @@ function infer()
function cal_acc()
{
    python ../postprocess.py --config_path=$CONFIG_FILE --label_file=$label_file --result_path=result_Files &> acc.log
    if [ $? -ne 0 ]; then
        echo "calculate accuracy failed"
        exit 1

View File

@@ -15,8 +15,12 @@
# ============================================================================
export DEVICE_ID=$1
export DATA_DIR=$2
export CKPT_PATH=$3
BASE_PATH=$(cd ./"`dirname $0`" || exit; pwd)
CONFIG_FILE="${BASE_PATH}/../default_config.yaml"
python train.py --config_path=$CONFIG_FILE \
--platform=Ascend \
--dataset_path=$DATA_DIR --ckpt_path=$CKPT_PATH > log.txt 2>&1 &

View File

@@ -14,5 +14,10 @@
# limitations under the License.
# ============================================================================
DATA_DIR=$1
CKPT_PATH=$2
BASE_PATH=$(cd ./"`dirname $0`" || exit; pwd)
CONFIG_FILE="${BASE_PATH}/../default_config_cpu.yaml"
python ./train.py --config_path=$CONFIG_FILE --platform 'CPU' --dataset_path $DATA_DIR \
--ckpt_path=$CKPT_PATH > train.log 2>&1 &

View File

@@ -15,5 +15,10 @@
# ============================================================================
DEVICE_ID=$1
DATA_DIR=$2
CKPT_PATH=$3
BASE_PATH=$(cd ./"`dirname $0`" || exit; pwd)
CONFIG_FILE="${BASE_PATH}/../default_config_gpu.yaml"
CUDA_VISIBLE_DEVICES=$DEVICE_ID python ./train.py --config_path=$CONFIG_FILE --platform 'GPU' \
--dataset_path $DATA_DIR --ckpt_path=$CKPT_PATH > train.log 2>&1 &

View File

@@ -1,100 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
network config setting, will be used in main.py
"""
from easydict import EasyDict as edict
config_gpu = edict({
'random_seed': 1,
'work_nums': 8,
'decay_method': 'cosine',
"loss_scale": 1,
'batch_size': 128,
'epoch_size': 250,
'num_classes': 1000,
'ds_type': 'imagenet',
'ds_sink_mode': True,
'smooth_factor': 0.1,
'aux_factor': 0.2,
'lr_init': 0.00004,
'lr_max': 0.4,
'lr_end': 0.000004,
'warmup_epochs': 1,
'weight_decay': 0.00004,
'momentum': 0.9,
'opt_eps': 1.0,
'keep_checkpoint_max': 10,
'ckpt_path': './',
'is_save_on_master': 0,
'dropout_keep_prob': 0.5,
'has_bias': True,
'amp_level': 'O0'
})
config_ascend = edict({
'random_seed': 1,
'work_nums': 8,
'decay_method': 'cosine',
"loss_scale": 1024,
'batch_size': 128,
'epoch_size': 250,
'num_classes': 1000,
'ds_type': 'imagenet',
'ds_sink_mode': True,
'smooth_factor': 0.1,
'aux_factor': 0.2,
'lr_init': 0.00004,
'lr_max': 0.4,
'lr_end': 0.000004,
'warmup_epochs': 1,
'weight_decay': 0.00004,
'momentum': 0.9,
'opt_eps': 1.0,
'keep_checkpoint_max': 10,
'ckpt_path': './',
'is_save_on_master': 0,
'dropout_keep_prob': 0.8,
'has_bias': False,
'amp_level': 'O3'
})
config_cpu = edict({
'random_seed': 1,
'work_nums': 8,
'decay_method': 'cosine',
"loss_scale": 1024,
'batch_size': 128,
'epoch_size': 120,
'num_classes': 10,
'ds_type': 'cifar10',
'ds_sink_mode': False,
'smooth_factor': 0.1,
'aux_factor': 0.2,
'lr_init': 0.00004,
'lr_max': 0.1,
'lr_end': 0.000004,
'warmup_epochs': 1,
'weight_decay': 0.00004,
'momentum': 0.9,
'opt_eps': 1.0,
'keep_checkpoint_max': 10,
'ckpt_path': './',
'is_save_on_master': 0,
'dropout_keep_prob': 0.8,
'has_bias': False,
'amp_level': 'O0',
})

View File

@@ -0,0 +1,127 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Parse arguments"""
import os
import ast
import argparse
from pprint import pprint, pformat
import yaml
class Config:
"""
Configuration namespace. Convert dictionary to members.
"""
def __init__(self, cfg_dict):
for k, v in cfg_dict.items():
if isinstance(v, (list, tuple)):
setattr(self, k, [Config(x) if isinstance(x, dict) else x for x in v])
else:
setattr(self, k, Config(v) if isinstance(v, dict) else v)
def __str__(self):
return pformat(self.__dict__)
def __repr__(self):
return self.__str__()
def parse_cli_to_yaml(parser, cfg, helper=None, choices=None, cfg_path="default_config.yaml"):
"""
Parse command line arguments to the configuration according to the default yaml.
Args:
parser: Parent parser.
cfg: Base configuration.
helper: Helper description.
cfg_path: Path to the default yaml config.
"""
parser = argparse.ArgumentParser(description="[REPLACE THIS at config.py]",
parents=[parser])
helper = {} if helper is None else helper
choices = {} if choices is None else choices
for item in cfg:
if not isinstance(cfg[item], list) and not isinstance(cfg[item], dict):
help_description = helper[item] if item in helper else "Please reference to {}".format(cfg_path)
choice = choices[item] if item in choices else None
if isinstance(cfg[item], bool):
parser.add_argument("--" + item, type=ast.literal_eval, default=cfg[item], choices=choice,
help=help_description)
else:
parser.add_argument("--" + item, type=type(cfg[item]), default=cfg[item], choices=choice,
help=help_description)
args = parser.parse_args()
return args
def parse_yaml(yaml_path):
"""
Parse the yaml config file.
Args:
yaml_path: Path to the yaml config.
"""
with open(yaml_path, 'r') as fin:
try:
cfgs = yaml.load_all(fin.read(), Loader=yaml.FullLoader)
cfgs = [x for x in cfgs]
if len(cfgs) == 1:
cfg_helper = {}
cfg = cfgs[0]
cfg_choices = {}
elif len(cfgs) == 2:
cfg, cfg_helper = cfgs
cfg_choices = {}
elif len(cfgs) == 3:
cfg, cfg_helper, cfg_choices = cfgs
else:
raise ValueError("At most 3 docs (config, description for help, choices) are supported in config yaml")
print(cfg_helper)
except:
raise ValueError("Failed to parse yaml")
return cfg, cfg_helper, cfg_choices
def merge(args, cfg):
"""
Merge the base config from yaml file and command line arguments.
Args:
args: Command line arguments.
cfg: Base configuration.
"""
args_var = vars(args)
for item in args_var:
cfg[item] = args_var[item]
return cfg
def get_config():
"""
Get Config according to the yaml file and cli arguments.
"""
parser = argparse.ArgumentParser(description="default name", add_help=False)
current_dir = os.path.dirname(os.path.abspath(__file__))
parser.add_argument("--config_path", type=str, default=os.path.join(current_dir, "../../default_config.yaml"),
help="Config file path")
path_args, _ = parser.parse_known_args()
default, helper, choices = parse_yaml(path_args.config_path)
pprint(default)
args = parse_cli_to_yaml(parser=parser, cfg=default, helper=helper, choices=choices, cfg_path=path_args.config_path)
final_config = merge(args, default)
return Config(final_config)
config = get_config()
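
Because `get_config()` registers one `argparse` flag per scalar YAML key, any default can be overridden from the command line. A hypothetical session, assuming the script is launched with a `--config_path` pointing at one of the profiles above:

```python
# Hypothetical override session. Launching, say,
#   python train.py --config_path=/path/to/default_config.yaml --lr_max=0.2
# makes the merged value visible anywhere the singleton is imported:
from src.model_utils.config import config

print(config.lr_max)    # 0.2 from the CLI override; 0.4 if left at the yaml default
print(config.platform)  # 'Ascend' from default_config.yaml
```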

View File

@@ -0,0 +1,27 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Device adapter for ModelArts"""
from .config import config

if config.enable_modelarts:
    from .moxing_adapter import get_device_id, get_device_num, get_rank_id, get_job_id
else:
    from .local_adapter import get_device_id, get_device_num, get_rank_id, get_job_id

__all__ = [
    "get_device_id", "get_device_num", "get_rank_id", "get_job_id"
]

View File

@@ -0,0 +1,36 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Local adapter"""
import os

def get_device_id():
    device_id = os.getenv('DEVICE_ID', '0')
    return int(device_id)

def get_device_num():
    device_num = os.getenv('RANK_SIZE', '1')
    return int(device_num)

def get_rank_id():
    global_rank_id = os.getenv('RANK_ID', '0')
    return int(global_rank_id)

def get_job_id():
    return "Local Job"

View File

@@ -0,0 +1,122 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Moxing adapter for ModelArts"""
import os
import functools
from mindspore import context
from mindspore.profiler import Profiler
from .config import config

_global_sync_count = 0

def get_device_id():
    device_id = os.getenv('DEVICE_ID', '0')
    return int(device_id)

def get_device_num():
    device_num = os.getenv('RANK_SIZE', '1')
    return int(device_num)

def get_rank_id():
    global_rank_id = os.getenv('RANK_ID', '0')
    return int(global_rank_id)

def get_job_id():
    job_id = os.getenv('JOB_ID')
    job_id = job_id if job_id != "" else "default"
    return job_id

def sync_data(from_path, to_path):
    """
    Download data from remote obs to local directory if the first url is remote url and the second one is local path.
    Upload data from local directory to remote obs in contrast.
    """
    import moxing as mox
    import time
    global _global_sync_count
    sync_lock = "/tmp/copy_sync.lock" + str(_global_sync_count)
    _global_sync_count += 1

    # Each server contains 8 devices at most.
    if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
        print("from path: ", from_path)
        print("to path: ", to_path)
        mox.file.copy_parallel(from_path, to_path)
        print("===finish data synchronization===")
        try:
            os.mknod(sync_lock)
        except IOError:
            pass
        print("===save flag===")

    while True:
        if os.path.exists(sync_lock):
            break
        time.sleep(1)

    print("Finish sync data from {} to {}.".format(from_path, to_path))

def moxing_wrapper(pre_process=None, post_process=None):
    """
    Moxing wrapper to download dataset and upload outputs.
    """
    def wrapper(run_func):
        @functools.wraps(run_func)
        def wrapped_func(*args, **kwargs):
            # Download data from data_url
            if config.enable_modelarts:
                if config.data_url:
                    sync_data(config.data_url, config.data_path)
                    print("Dataset downloaded: ", os.listdir(config.data_path))
                if config.checkpoint_url:
                    sync_data(config.checkpoint_url, config.load_path)
                    print("Preload downloaded: ", os.listdir(config.load_path))
                if config.train_url:
                    sync_data(config.train_url, config.output_path)
                    print("Workspace downloaded: ", os.listdir(config.output_path))
                context.set_context(save_graphs_path=os.path.join(config.output_path, str(get_rank_id())))
                config.device_num = get_device_num()
                config.device_id = get_device_id()
                if not os.path.exists(config.output_path):
                    os.makedirs(config.output_path)

            if pre_process:
                pre_process()

            if config.enable_profiling:
                profiler = Profiler()

            run_func(*args, **kwargs)

            if config.enable_profiling:
                profiler.analyse()

            # Upload data to train_url
            if config.enable_modelarts:
                if post_process:
                    post_process()

                if config.train_url:
                    print("Start to copy output directory")
                    sync_data(config.output_path, config.train_url)
        return wrapped_func
    return wrapper
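
A minimal usage sketch of the decorator (hypothetical script, assuming a populated `config`): with `enable_modelarts` set, inputs are synced down from OBS before the wrapped function runs, and outputs are synced back afterwards.

```python
# Hypothetical launch script showing the wrapper's lifecycle: sync inputs
# from config.data_url, run pre_process, run the function, then sync
# config.output_path back to config.train_url (all gated on enable_modelarts).
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper

def prepare():
    print("dataset ready under", config.data_path)  # e.g. unzip the archive here

@moxing_wrapper(pre_process=prepare)
def main():
    print("training from", config.dataset_path)

if __name__ == '__main__':
    main()
```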

View File

@@ -13,10 +13,17 @@
# limitations under the License.
# ============================================================================
"""train_imagenet."""
import argparse
import time
import os
import mindspore.nn as nn
from src.model_utils.config import config
from src.model_utils.moxing_adapter import moxing_wrapper
from src.model_utils.device_adapter import get_device_id, get_device_num
from src.dataset import create_dataset_imagenet, create_dataset_cifar10
from src.inception_v3 import InceptionV3
from src.lr_generator import get_lr
from src.loss import CrossEntropy
from mindspore import Tensor
from mindspore import context
from mindspore.context import ParallelMode
@@ -29,69 +36,110 @@ from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.common.initializer import XavierUniform, initializer
from mindspore.common import set_seed

set_seed(1)

DS_DICT = {
    "imagenet": create_dataset_imagenet,
    "cifar10": create_dataset_cifar10,
}

def modelarts_pre_process():
    def unzip(zip_file, save_dir):
        import zipfile
        s_time = time.time()
        if not os.path.exists(os.path.join(save_dir, config.modelarts_dataset_unzip_name)):
            zip_isexist = zipfile.is_zipfile(zip_file)
            if zip_isexist:
                fz = zipfile.ZipFile(zip_file, 'r')
                data_num = len(fz.namelist())
                print("Extract Start...")
                print("unzip file num: {}".format(data_num))
                data_print = int(data_num / 100) if data_num > 100 else 1
                i = 0
                for file in fz.namelist():
                    if i % data_print == 0:
                        print("unzip percent: {}%".format(int(i * 100 / data_num)), flush=True)
                    i += 1
                    fz.extract(file, save_dir)
                print("cost time: {}min:{}s.".format(int((time.time() - s_time) / 60),
                                                     int(int(time.time() - s_time) % 60)))
                print("Extract Done")
            else:
                print("This is not zip.")
        else:
            print("Zip has been extracted.")

    if config.need_modelarts_dataset_unzip:
        zip_file_1 = os.path.join(config.data_path, config.modelarts_dataset_unzip_name + ".zip")
        save_dir_1 = os.path.join(config.data_path)
        sync_lock = "/tmp/unzip_sync.lock"

        # Each server contains 8 devices at most
        if get_device_id() % min(get_device_num(), 8) == 0 and not os.path.exists(sync_lock):
            print("Zip file path: ", zip_file_1)
            print("Unzip file save dir: ", save_dir_1)
            unzip(zip_file_1, save_dir_1)
            print("===Finish extract data synchronization===")
            try:
                os.mknod(sync_lock)
            except IOError:
                pass

        while True:
            if os.path.exists(sync_lock):
                break
            time.sleep(1)

        print("Device: {}, Finish sync unzip data from {} to {}.".format(get_device_id(), zip_file_1, save_dir_1))
        print("#" * 200, os.listdir(save_dir_1))
        print("#" * 200, os.listdir(os.path.join(config.data_path, config.modelarts_dataset_unzip_name)))

        config.dataset_path = os.path.join(config.data_path, config.modelarts_dataset_unzip_name)
    config.ckpt_path = config.output_path

@moxing_wrapper(pre_process=modelarts_pre_process)
def train_inceptionv3():
    print(config)
    config.dataset_path = os.path.join(config.dataset_path, 'train')
    create_dataset = DS_DICT[config.ds_type]
    if config.platform == "GPU":
        context.set_context(enable_graph_kernel=True)
    context.set_context(mode=context.GRAPH_MODE, device_target=config.platform, save_graphs=False)
    if os.getenv('DEVICE_ID', "not_set").isdigit():
        context.set_context(device_id=int(os.getenv('DEVICE_ID')))

    # init distributed
    if config.is_distributed:
        init()
        config.rank = get_rank()
        config.group_size = get_group_size()
        parallel_mode = ParallelMode.DATA_PARALLEL
        context.set_auto_parallel_context(parallel_mode=parallel_mode, device_num=config.group_size,
                                          gradients_mean=True)
    else:
        config.rank = 0
        config.group_size = 1

    # dataloader
    dataset = create_dataset(config.dataset_path, True, config)
    batches_per_epoch = dataset.get_dataset_size()

    # network
    net = InceptionV3(num_classes=config.num_classes, dropout_keep_prob=config.dropout_keep_prob,
                      has_bias=config.has_bias)

    # loss
    loss = CrossEntropy(smooth_factor=config.smooth_factor, num_classes=config.num_classes, factor=config.aux_factor)

    # learning rate schedule
    lr = get_lr(lr_init=config.lr_init, lr_end=config.lr_end, lr_max=config.lr_max, warmup_epochs=config.warmup_epochs,
                total_epochs=config.epoch_size, steps_per_epoch=batches_per_epoch, lr_decay_mode=config.decay_method)
    lr = Tensor(lr)

    # optimizer
@@ -103,41 +151,45 @@ if __name__ == '__main__':
        else:
            no_decayed_params.append(param)

    if config.platform == "Ascend":
        for param in net.trainable_params():
            if 'beta' not in param.name and 'gamma' not in param.name and 'bias' not in param.name:
                param.set_data(initializer(XavierUniform(), param.data.shape, param.data.dtype))
    group_params = [{'params': decayed_params, 'weight_decay': config.weight_decay},
                    {'params': no_decayed_params},
                    {'order_params': net.trainable_params()}]
    optimizer = RMSProp(group_params, lr, decay=0.9, weight_decay=config.weight_decay,
                        momentum=config.momentum, epsilon=config.opt_eps, loss_scale=config.loss_scale)
    # eval_metrics = {'Loss': nn.Loss(), 'Top1-Acc': nn.Top1CategoricalAccuracy(),
    #                 'Top5-Acc': nn.Top5CategoricalAccuracy()}

    if config.resume:
        ckpt = load_checkpoint(config.resume)
        load_param_into_net(net, ckpt)

    if config.platform == "Ascend":
        loss_scale_manager = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
        model = Model(net, loss_fn=loss, optimizer=optimizer, metrics={'acc'}, amp_level=config.amp_level,
                      loss_scale_manager=loss_scale_manager)
    else:
        model = Model(net, loss_fn=loss, optimizer=optimizer, metrics={'acc'}, amp_level=config.amp_level)

    print("============== Starting Training ==============")
    loss_cb = LossMonitor(per_print_times=batches_per_epoch)
    time_cb = TimeMonitor(data_size=batches_per_epoch)
    callbacks = [loss_cb, time_cb]
    config_ck = CheckpointConfig(save_checkpoint_steps=batches_per_epoch,
                                 keep_checkpoint_max=config.keep_checkpoint_max)
    save_ckpt_path = os.path.join(config.ckpt_path, 'ckpt_' + str(config.rank) + '/')
    ckpoint_cb = ModelCheckpoint(prefix=f"inceptionv3-rank{config.rank}", directory=save_ckpt_path, config=config_ck)
    if config.is_distributed & config.is_save_on_master:
        if config.rank == 0:
            callbacks.append(ckpoint_cb)
        model.train(config.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=config.ds_sink_mode)
    else:
        callbacks.append(ckpoint_cb)
        model.train(config.epoch_size, dataset, callbacks=callbacks, dataset_sink_mode=config.ds_sink_mode)
    print("train success")

if __name__ == '__main__':
    train_inceptionv3()