!15913 inceptionv4 support cpu training

From: @caojian05
Reviewed-by: @wuxuejian
Signed-off-by: @wuxuejian
Committed by mindspore-ci-bot on 2021-05-12 15:57:32 +08:00 (via Gitee)
commit 23fc8506d2
9 changed files with 275 additions and 88 deletions

batch_norm_cpu_kernel.cc View File

@@ -35,7 +35,9 @@ void BatchNormCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   is_train = AnfAlgo::GetNodeAttr<bool>(kernel_node, "is_training");
   momentum = AnfAlgo::GetNodeAttr<float>(kernel_node, "momentum");
   std::vector<size_t> x_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
-  if (x_shape.size() != 4) {
+  if (x_shape.size() == 2) {
+    x_shape.insert(x_shape.end(), 2, 1);
+  } else if (x_shape.size() != 4) {
     MS_LOG(EXCEPTION) << "Batchnorm only support nchw input!";
   }
   batch_size = x_shape[0];

batch_norm_grad_cpu_kernel.cc View File

@@ -38,7 +38,9 @@ void BatchNormGradCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
 void BatchNormGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   MS_EXCEPTION_IF_NULL(kernel_node);
   std::vector<size_t> x_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
-  if (x_shape.size() != 4) {
+  if (x_shape.size() == 2) {
+    x_shape.insert(x_shape.end(), 2, 1);
+  } else if (x_shape.size() != 4) {
     MS_LOG(EXCEPTION) << "Fused batchnorm only support nchw input!";
   }
   batch_size = x_shape[0];
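The same two-line pattern appears in both the forward and the gradient kernel: a 2-D (N, C) input, the shape a fully connected layer produces, is padded to (N, C, 1, 1) so the existing NCHW code path applies unchanged. The padding is lossless because per-channel batch statistics are identical when H = W = 1. A quick NumPy sketch of the equivalence (illustrative only, not MindSpore code):

```python
import numpy as np

# A (N, C) activation batch, e.g. the output of a fully connected layer.
x = np.random.randn(8, 16).astype(np.float32)

# Per-channel statistics on the 2-D layout...
mean_2d, var_2d = x.mean(axis=0), x.var(axis=0)

# ...equal the NCHW statistics after padding the shape to (N, C, 1, 1),
# which is what x_shape.insert(x_shape.end(), 2, 1) arranges.
x_4d = x.reshape(8, 16, 1, 1)
mean_4d, var_4d = x_4d.mean(axis=(0, 2, 3)), x_4d.var(axis=(0, 2, 3))

assert np.allclose(mean_2d, mean_4d) and np.allclose(var_2d, var_4d)
```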

README.md View File

@@ -71,6 +71,8 @@ For FP16 operators, if the input data type is FP32, the backend of MindSpore will
 ├─scripts
   ├─run_distribute_train_gpu.sh        # launch distributed training with gpu platform(8p)
   ├─run_eval_gpu.sh                    # launch evaluating with gpu platform
+  ├─run_eval_cpu.sh                    # launch evaluating with cpu platform
+  ├─run_standalone_train_cpu.sh        # launch standalone training with cpu platform(1p)
   ├─run_standalone_train_ascend.sh     # launch standalone training with ascend platform(1p)
   ├─run_distribute_train_ascend.sh     # launch distributed training with ascend platform(8p)
   ├─run_infer_310.sh                   # shell script for 310 inference
@@ -138,6 +140,13 @@ sh scripts/run_standalone_train_ascend.sh DEVICE_ID DATA_DIR
 sh scripts/run_distribute_train_gpu.sh DATA_PATH
 ```
+
+- CPU:
+
+```bash
+# standalone training example with shell
+sh scripts/run_standalone_train_cpu.sh DATA_PATH
+```
 
 ### Launch
 
 ```bash
@@ -151,6 +160,9 @@ sh scripts/run_distribute_train_gpu.sh DATA_PATH
 GPU:
 # distribute training example(8p)
 sh scripts/run_distribute_train_gpu.sh DATA_PATH
+CPU:
+# standalone training example with shell
+sh scripts/run_standalone_train_cpu.sh DATA_PATH
 ```
 
 ### Result

eval.py View File

@@ -18,23 +18,35 @@ import os
 import mindspore.nn as nn
 from mindspore import context
+from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
 from mindspore.train.model import Model
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
-from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
-from src.dataset import create_dataset
+from src.config import config_ascend, config_gpu, config_cpu
+from src.dataset import create_dataset_imagenet, create_dataset_cifar10
 from src.inceptionv4 import Inceptionv4
-from src.config import config
+
+CFG_DICT = {
+    "Ascend": config_ascend,
+    "GPU": config_gpu,
+    "CPU": config_cpu,
+}
+
+DS_DICT = {
+    "imagenet": create_dataset_imagenet,
+    "cifar10": create_dataset_cifar10,
+}
 
 
 def parse_args():
     '''parse_args'''
     parser = argparse.ArgumentParser(description='image classification evaluation')
-    parser.add_argument('--platform', type=str, default='Ascend', choices=('Ascend', 'GPU'), help='run platform')
+    parser.add_argument('--platform', type=str, default='Ascend', choices=('Ascend', 'GPU', 'CPU'), help='run platform')
     parser.add_argument('--dataset_path', type=str, default='', help='Dataset path')
     parser.add_argument('--checkpoint_path', type=str, default='', help='checkpoint of inceptionV4')
     args_opt = parser.parse_args()
     return args_opt
 
 
 if __name__ == '__main__':
     args = parse_args()
@@ -42,18 +54,22 @@ if __name__ == '__main__':
         device_id = int(os.getenv('DEVICE_ID', '0'))
         context.set_context(device_id=device_id)
 
+    config = CFG_DICT[args.platform]
+    create_dataset = DS_DICT[config.ds_type]
     context.set_context(mode=context.GRAPH_MODE, device_target=args.platform)
     net = Inceptionv4(classes=config.num_classes)
     ckpt = load_checkpoint(args.checkpoint_path)
     load_param_into_net(net, ckpt)
     net.set_train(False)
-    dataset = create_dataset(dataset_path=args.dataset_path, do_train=False,
-                             repeat_num=1, batch_size=config.batch_size)
+
+    config.rank = 0
+    config.group_size = 1
+    dataset = create_dataset(dataset_path=args.dataset_path, do_train=False, cfg=config)
     loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
     eval_metrics = {'Loss': nn.Loss(),
                     'Top1-Acc': nn.Top1CategoricalAccuracy(),
                     'Top5-Acc': nn.Top5CategoricalAccuracy()}
     model = Model(net, loss, optimizer=None, metrics=eval_metrics)
     print('=' * 20, 'Evalute start', '=' * 20)
-    metrics = model.eval(dataset)
+    metrics = model.eval(dataset, dataset_sink_mode=config.ds_sink_mode)
     print("metric: ", metrics)

scripts/run_eval_cpu.sh View File

@@ -0,0 +1,28 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+rm -rf evaluation
+mkdir evaluation
+cp ./*.py ./evaluation
+cp -r ./src ./evaluation
+cd ./evaluation || exit
+
+DATA_DIR=$1
+CKPT_DIR=$2
+
+echo "start evaluation"
+python eval.py --dataset_path=$DATA_DIR --checkpoint_path=$CKPT_DIR --platform='CPU' > eval.log 2>&1 &
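Usage, from the model's root directory: `sh scripts/run_eval_cpu.sh DATA_DIR CKPT_DIR`. The script stages the sources into a fresh `./evaluation` directory, runs evaluation in the background, and writes all output to `./evaluation/eval.log`.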

scripts/run_standalone_train_cpu.sh View File

@@ -0,0 +1,25 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+DATA_DIR=$1
+
+rm -rf train_standalone
+mkdir ./train_standalone
+cd ./train_standalone || exit
+env > env.log
+python -u ../train.py \
+    --dataset_path=$DATA_DIR --platform=CPU > log.txt 2>&1 &
+cd ../

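Usage, matching the README: `sh scripts/run_standalone_train_cpu.sh DATA_PATH`. Training runs in the background inside a fresh `./train_standalone` directory; the environment snapshot goes to `env.log` and training output to `log.txt` in that directory.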
src/config.py View File

@@ -17,13 +17,15 @@ network config setting, will be used in main.py
 """
 from easydict import EasyDict as edict
 
-config = edict({
+config_ascend = edict({
     'is_save_on_master': False,
     'batch_size': 128,
     'epoch_size': 250,
     'num_classes': 1000,
     'work_nums': 8,
+    'ds_type': 'imagenet',
+    'ds_sink_mode': True,
     'loss_scale': 1024,
     'smooth_factor': 0.1,
@@ -42,3 +44,57 @@ config = edict({
     'warmup_epochs': 1,
     'start_epoch': 1,
 })
+
+config_gpu = edict({
+    'is_save_on_master': False,
+    'batch_size': 128,
+    'epoch_size': 250,
+    'num_classes': 1000,
+    'work_nums': 8,
+    'ds_type': 'imagenet',
+    'ds_sink_mode': True,
+    'loss_scale': 1024,
+    'smooth_factor': 0.1,
+    'weight_decay': 0.00004,
+    'momentum': 0.9,
+    'amp_level': 'O0',
+    'decay': 0.9,
+    'epsilon': 1.0,
+    'keep_checkpoint_max': 10,
+    'save_checkpoint_epochs': 10,
+    'lr_init': 0.00004,
+    'lr_end': 0.000004,
+    'lr_max': 0.4,
+    'warmup_epochs': 1,
+    'start_epoch': 1,
+})
+
+config_cpu = edict({
+    'batch_size': 128,
+    'epoch_size': 250,
+    'num_classes': 10,
+    'work_nums': 8,
+    'ds_type': 'cifar10',
+    'ds_sink_mode': False,
+    'loss_scale': 1024,
+    'smooth_factor': 0.1,
+    'weight_decay': 0.00004,
+    'momentum': 0.9,
+    'amp_level': 'O0',
+    'decay': 0.9,
+    'epsilon': 1.0,
+    'keep_checkpoint_max': 10,
+    'save_checkpoint_epochs': 10,
+    'lr_init': 0.00004,
+    'lr_end': 0.000004,
+    'lr_max': 0.4,
+    'warmup_epochs': 1,
+    'start_epoch': 1,
+})
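The three profiles share the optimizer and learning-rate settings and differ mainly in dataset wiring: the CPU profile switches to CIFAR-10 (10 classes) and turns dataset sink mode off. A quick sanity-check sketch (assumes the package root is on `sys.path` and `easydict` is installed):

```python
# Sketch: compare the three profiles on the fields that actually vary.
from src.config import config_ascend, config_gpu, config_cpu

for key in ('ds_type', 'num_classes', 'ds_sink_mode'):
    print(key, config_ascend.get(key), config_gpu.get(key), config_cpu.get(key))
# ds_type      imagenet  imagenet  cifar10
# num_classes  1000      1000      10
# ds_sink_mode True      True      False
```

Note that `config_cpu` omits `'is_save_on_master'`; train.py only reads that key on the multi-device path, which a CPU run never takes.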

src/dataset.py View File

@@ -12,68 +12,100 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-"""Create train or eval dataset."""
-import os
-import mindspore.common.dtype as mstype
-import mindspore.dataset as de
-import mindspore.dataset.vision.c_transforms as C
-import mindspore.dataset.transforms.c_transforms as C2
-from src.config import config
-
-device_id = int(os.getenv('DEVICE_ID', '0'))
-device_num = int(os.getenv('RANK_SIZE', '1'))
-
-def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, shard_id=0):
+"""
+Create a train or eval dataset.
+Data operations, will be used in train.py and eval.py
+"""
+import os
+
+import mindspore.common.dtype as mstype
+import mindspore.dataset as ds
+import mindspore.dataset.transforms.c_transforms as C2
+import mindspore.dataset.vision.c_transforms as C
+
+
+def create_dataset_imagenet(dataset_path, do_train, cfg, repeat_num=1):
     """
     create a train or eval dataset
 
     Args:
-        dataset_path (str): The path of dataset.
-        do_train (bool): Whether dataset is used for train or eval.
-        repeat_num (int): The repeat times of dataset. Default: 1.
-        batch_size (int): The batch size of dataset. Default: 32.
+        dataset_path(string): the path of dataset.
+        do_train(bool): whether dataset is used for train or eval.
+        cfg (dict): the config for creating dataset.
+        repeat_num(int): the repeat times of dataset. Default: 1.
 
     Returns:
-        Dataset.
+        dataset
     """
-    do_shuffle = bool(do_train)
-    if device_num == 1 or not do_train:
-        ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=config.work_nums, shuffle=do_shuffle)
+    if cfg.group_size == 1:
+        data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=cfg.work_nums, shuffle=do_train)
     else:
-        ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=config.work_nums,
-                                   shuffle=do_shuffle, num_shards=device_num, shard_id=shard_id)
-    image_length = 299
+        data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=cfg.work_nums, shuffle=do_train,
+                                         num_shards=cfg.group_size, shard_id=cfg.rank)
     # define map operations
+    size = 299
     if do_train:
         trans = [
-            C.RandomCropDecodeResize(image_length, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
+            C.RandomCropDecodeResize(size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
             C.RandomHorizontalFlip(prob=0.5),
             C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4)
         ]
     else:
         trans = [
             C.Decode(),
-            C.Resize(image_length),
-            C.CenterCrop(image_length)
+            C.Resize(size),
+            C.CenterCrop(size)
        ]
     trans += [
         C.Rescale(1.0 / 255.0, 0.0),
         C.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
         C.HWC2CHW()
     ]
+    type_cast_op = C2.TypeCast(mstype.int32)
+    data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=cfg.work_nums)
+    data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=cfg.work_nums)
+    # apply batch operations
+    data_set = data_set.batch(cfg.batch_size, drop_remainder=True)
+    # apply dataset repeat operation
+    data_set = data_set.repeat(repeat_num)
+    return data_set
+
+
+def create_dataset_cifar10(dataset_path, do_train, cfg, repeat_num=1):
+    """
+    create a train or eval dataset
+
+    Args:
+        dataset_path(string): the path of dataset.
+        do_train(bool): whether dataset is used for train or eval.
+        cfg (dict): the config for creating dataset.
+        repeat_num(int): the repeat times of dataset. Default: 1.
+
+    Returns:
+        dataset
+    """
+    dataset_path = os.path.join(dataset_path, "cifar-10-batches-bin" if do_train else "cifar-10-verify-bin")
+    if cfg.group_size == 1:
+        data_set = ds.Cifar10Dataset(dataset_path, num_parallel_workers=cfg.work_nums, shuffle=do_train)
+    else:
+        data_set = ds.Cifar10Dataset(dataset_path, num_parallel_workers=cfg.work_nums, shuffle=do_train,
+                                     num_shards=cfg.group_size, shard_id=cfg.rank)
+    # define map operations
+    trans = []
+    if do_train:
+        trans.append(C.RandomCrop((32, 32), (4, 4, 4, 4)))
+        trans.append(C.RandomHorizontalFlip(prob=0.5))
+    trans.append(C.Resize((299, 299)))
+    trans.append(C.Rescale(1.0 / 255.0, 0.0))
+    trans.append(C.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]))
+    trans.append(C.HWC2CHW())
     type_cast_op = C2.TypeCast(mstype.int32)
-    ds = ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=config.work_nums)
-    ds = ds.map(input_columns="image", operations=trans, num_parallel_workers=config.work_nums)
+    data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=cfg.work_nums)
+    data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=cfg.work_nums)
     # apply batch operations
-    ds = ds.batch(batch_size, drop_remainder=True)
+    data_set = data_set.batch(cfg.batch_size, drop_remainder=do_train)
     # apply dataset repeat operation
-    ds = ds.repeat(repeat_num)
-    return ds
+    data_set = data_set.repeat(repeat_num)
+    return data_set
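Both builders end with the same map, batch, repeat tail. CIFAR-10 images are additionally resized from 32x32 up to the 299x299 input InceptionV4 expects, and `drop_remainder=do_train` keeps every sample at eval time while guaranteeing full batches during training (e.g. 50000 // 128 = 390 training steps per epoch). The numeric effect of the shared Rescale/Normalize/HWC2CHW tail, sketched in NumPy (illustrative, not the MindSpore ops themselves):

```python
import numpy as np

# Rescale(1/255, 0) followed by Normalize(mean, std) maps a uint8 pixel p to
# ((p / 255) - mean) / std per channel; HWC2CHW then moves channels first.
mean = np.array([0.4914, 0.4822, 0.4465], dtype=np.float32)  # CIFAR-10 values
std = np.array([0.2023, 0.1994, 0.2010], dtype=np.float32)

img_hwc = np.random.randint(0, 256, (299, 299, 3)).astype(np.float32)
img_hwc = (img_hwc / 255.0 - mean) / std   # Rescale + Normalize
img_chw = img_hwc.transpose(2, 0, 1)       # HWC2CHW
print(img_chw.shape)                       # (3, 299, 299)
```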

train.py View File

@@ -13,32 +13,63 @@
 # limitations under the License.
 # ============================================================================
 """train imagenet"""
-import os
 import argparse
 import math
+import os
 import numpy as np
-from mindspore.communication import init, get_rank
-from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor
-from mindspore.train.model import ParallelMode
-from mindspore.train.loss_scale_manager import FixedLossScaleManager
 from mindspore import Model
-from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
-from mindspore.nn import RMSProp
 from mindspore import Tensor
 from mindspore import context
 from mindspore.common import set_seed
 from mindspore.common.initializer import XavierUniform, initializer
+from mindspore.communication import init, get_rank, get_group_size
+from mindspore.nn import RMSProp
+from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
+from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor
+from mindspore.train.loss_scale_manager import FixedLossScaleManager
+from mindspore.train.model import ParallelMode
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from src.config import config_ascend, config_gpu, config_cpu
+from src.dataset import create_dataset_imagenet, create_dataset_cifar10
 from src.inceptionv4 import Inceptionv4
-from src.dataset import create_dataset, device_num
-from src.config import config
 
 os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'
 set_seed(1)
 
+CFG_DICT = {
+    "Ascend": config_ascend,
+    "GPU": config_gpu,
+    "CPU": config_cpu,
+}
+
+DS_DICT = {
+    "imagenet": create_dataset_imagenet,
+    "cifar10": create_dataset_cifar10,
+}
+
+device_num = int(os.getenv('RANK_SIZE', '1'))
+
+
+def parse_args():
+    '''parse_args'''
+    arg_parser = argparse.ArgumentParser(description='InceptionV4 image classification training')
+    arg_parser.add_argument('--dataset_path', type=str, default='', help='Dataset path')
+    arg_parser.add_argument('--device_id', type=int, default=0, help='device id')
+    arg_parser.add_argument('--platform', type=str, default='Ascend', choices=("Ascend", "GPU", "CPU"),
+                            help='Platform, support Ascend, GPU, CPU.')
+    arg_parser.add_argument('--resume', type=str, default='', help='resume training with existed checkpoint')
+    args_opt = arg_parser.parse_args()
+    return args_opt
+
+
+args = parse_args()
+config = CFG_DICT[args.platform]
+create_dataset = DS_DICT[config.ds_type]
+
+
 def generate_cosine_lr(steps_per_epoch, total_epochs,
                        lr_init=config.lr_init,
                        lr_end=config.lr_end,
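The hunk's trailing context cuts off `generate_cosine_lr` mid-signature, so its body is not shown here. For orientation only, a generic warmup-plus-cosine schedule built from the same config fields might look like the sketch below. This is an assumption about the shape of the schedule, not the file's actual implementation:

```python
import math
import numpy as np

def cosine_lr_sketch(steps_per_epoch, total_epochs,
                     lr_init=0.00004, lr_end=0.000004, lr_max=0.4, warmup_epochs=1):
    """Linear warmup from lr_init to lr_max, then cosine decay down to lr_end."""
    total_steps = steps_per_epoch * total_epochs
    warmup_steps = steps_per_epoch * warmup_epochs
    lr_each_step = []
    for step in range(total_steps):
        if step < warmup_steps:
            lr = lr_init + (lr_max - lr_init) * step / warmup_steps
        else:
            frac = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
            lr = lr_end + (lr_max - lr_end) * 0.5 * (1.0 + math.cos(math.pi * frac))
        lr_each_step.append(lr)
    return np.array(lr_each_step, dtype=np.float32)
```

The defaults mirror `lr_init`, `lr_end`, `lr_max`, and `warmup_epochs` from the configs above.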
@@ -87,7 +118,6 @@ def inception_v4_train():
     context.set_context(device_id=args.device_id)
     context.set_context(enable_graph_kernel=False)
-    rank = 0
     if device_num > 1:
         if args.platform == "Ascend":
             init(backend_name='hccl')
@@ -96,15 +126,18 @@
         else:
             raise ValueError("Unsupported device target.")
-        rank = get_rank()
+        config.rank = get_rank()
+        config.group_size = get_group_size()
         context.set_auto_parallel_context(device_num=device_num,
                                           parallel_mode=ParallelMode.DATA_PARALLEL,
                                           gradients_mean=True,
                                           all_reduce_fusion_config=[200, 400])
+    else:
+        config.rank = 0
+        config.group_size = 1
 
     # create dataset
-    train_dataset = create_dataset(dataset_path=args.dataset_path, do_train=True,
-                                   repeat_num=1, batch_size=config.batch_size, shard_id=rank)
+    train_dataset = create_dataset(dataset_path=args.dataset_path, do_train=True, cfg=config)
     train_step_size = train_dataset.get_dataset_size()
 
     # create model
@@ -140,23 +173,16 @@ def inception_v4_train():
     loss_scale_manager = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
-    if args.platform == "Ascend":
     model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc', 'top_1_accuracy', 'top_5_accuracy'},
                   loss_scale_manager=loss_scale_manager, amp_level=config.amp_level)
-    elif args.platform == "GPU":
-        model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc', 'top_1_accuracy', 'top_5_accuracy'},
-                      loss_scale_manager=loss_scale_manager, amp_level='O0')
-    else:
-        raise ValueError("Unsupported device target.")
-
     # define callbacks
     performance_cb = TimeMonitor(data_size=train_step_size)
     loss_cb = LossMonitor(per_print_times=train_step_size)
     ckp_save_step = config.save_checkpoint_epochs * train_step_size
     config_ck = CheckpointConfig(save_checkpoint_steps=ckp_save_step, keep_checkpoint_max=config.keep_checkpoint_max)
-    ckpoint_cb = ModelCheckpoint(prefix=f"inceptionV4-train-rank{rank}",
-                                 directory='ckpts_rank_' + str(rank), config=config_ck)
+    ckpoint_cb = ModelCheckpoint(prefix=f"inceptionV4-train-rank{config.rank}",
+                                 directory='ckpts_rank_' + str(config.rank), config=config_ck)
     callbacks = [performance_cb, loss_cb]
     if device_num > 1 and config.is_save_on_master:
         if args.device_id == 0:
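Checkpoint files are now tagged with `config.rank` instead of a local `rank` variable, so the single-process CPU run writes to `ckpts_rank_0`. The save cadence in concrete terms, a small arithmetic sketch using `config_cpu` values and assuming the 50000-image CIFAR-10 train split:

```python
batch_size = 128                 # config_cpu
save_checkpoint_epochs = 10      # config_cpu

train_step_size = 50000 // batch_size                     # 390 steps per epoch
ckp_save_step = save_checkpoint_epochs * train_step_size
print(ckp_save_step)             # 3900: one checkpoint every ten epochs
```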
@@ -165,21 +191,9 @@ def inception_v4_train():
             callbacks.append(ckpoint_cb)
 
     # train model
-    model.train(config.epoch_size, train_dataset, callbacks=callbacks, dataset_sink_mode=True)
-
-def parse_args():
-    '''parse_args'''
-    arg_parser = argparse.ArgumentParser(description='InceptionV4 image classification training')
-    arg_parser.add_argument('--dataset_path', type=str, default='', help='Dataset path')
-    arg_parser.add_argument('--device_id', type=int, default=0, help='device id')
-    arg_parser.add_argument('--platform', type=str, default='Ascend', choices=("Ascend", "GPU"),
-                            help='Platform, support Ascend, GPU.')
-    arg_parser.add_argument('--resume', type=str, default='', help='resume training with existed checkpoint')
-    args_opt = arg_parser.parse_args()
-    return args_opt
+    model.train(config.epoch_size, train_dataset, callbacks=callbacks, dataset_sink_mode=config.ds_sink_mode)
 
 
 if __name__ == '__main__':
-    args = parse_args()
     inception_v4_train()
     print('Inceptionv4 training success!')