forked from mindspore-Ecosystem/mindspore

!15913 inceptionv4 support cpu training
From: @caojian05
Reviewed-by: @wuxuejian
Signed-off-by: @wuxuejian
Commit: 23fc8506d2
backend CPU batch-norm kernel:

@@ -35,7 +35,9 @@ void BatchNormCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   is_train = AnfAlgo::GetNodeAttr<bool>(kernel_node, "is_training");
   momentum = AnfAlgo::GetNodeAttr<float>(kernel_node, "momentum");
   std::vector<size_t> x_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
-  if (x_shape.size() != 4) {
+  if (x_shape.size() == 2) {
+    x_shape.insert(x_shape.end(), 2, 1);
+  } else if (x_shape.size() != 4) {
     MS_LOG(EXCEPTION) << "Batchnorm only support nchw input!";
   }
   batch_size = x_shape[0];
backend CPU batch-norm grad kernel:

@@ -38,7 +38,9 @@ void BatchNormGradCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
 void BatchNormGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   MS_EXCEPTION_IF_NULL(kernel_node);
   std::vector<size_t> x_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
-  if (x_shape.size() != 4) {
+  if (x_shape.size() == 2) {
+    x_shape.insert(x_shape.end(), 2, 1);
+  } else if (x_shape.size() != 4) {
     MS_LOG(EXCEPTION) << "Fused batchnorm only support nchw input!";
   }
   batch_size = x_shape[0];
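Both kernel changes above relax the NCHW-only check: a 2-D (N, C) input is padded with two trailing singleton dimensions so the existing 4-D code path can serve it unchanged. A minimal NumPy sketch (shapes are illustrative, not from the commit) of why that padding preserves batch-norm semantics:

```python
import numpy as np

# A (N, C) activation, as produced by a dense layer.
x = np.random.randn(8, 16).astype(np.float32)

# Mirror x_shape.insert(x_shape.end(), 2, 1): pad to NCHW with H = W = 1.
x_nchw = x.reshape(8, 16, 1, 1)

# Per-channel batch statistics are identical in both layouts, so the
# 4-D kernel computes the same normalization as a dedicated 2-D path would.
assert np.allclose(x.mean(axis=0), x_nchw.mean(axis=(0, 2, 3)))
assert np.allclose(x.var(axis=0), x_nchw.var(axis=(0, 2, 3)))
```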
README.md:

@@ -71,6 +71,8 @@ For FP16 operators, if the input data type is FP32, the backend of MindSpore will
   ├─scripts
     ├─run_distribute_train_gpu.sh     # launch distributed training with gpu platform(8p)
     ├─run_eval_gpu.sh                 # launch evaluating with gpu platform
+    ├─run_eval_cpu.sh                 # launch evaluating with cpu platform
+    ├─run_standalone_train_cpu.sh     # launch standalone training with cpu platform(1p)
     ├─run_standalone_train_ascend.sh  # launch standalone training with ascend platform(1p)
     ├─run_distribute_train_ascend.sh  # launch distributed training with ascend platform(8p)
     ├─run_infer_310.sh                # shell script for 310 inference
@@ -138,6 +140,13 @@ sh scripts/run_standalone_train_ascend.sh DEVICE_ID DATA_DIR
 sh scripts/run_distribute_train_gpu.sh DATA_PATH
 ```
 
+- CPU:
+
+```bash
+# standalone training example with shell
+sh scripts/run_standalone_train_cpu.sh DATA_PATH
+```
+
 ### Launch
 
 ```bash
@@ -151,6 +160,9 @@ sh scripts/run_distribute_train_gpu.sh DATA_PATH
 GPU:
 # distribute training example(8p)
 sh scripts/run_distribute_train_gpu.sh DATA_PATH
+
+CPU:
+# standalone training example with shell
+sh scripts/run_standalone_train_cpu.sh DATA_PATH
 ```
 
 ### Result
eval.py:

@@ -18,23 +18,35 @@ import os
 
 import mindspore.nn as nn
 from mindspore import context
+from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
 from mindspore.train.model import Model
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
-from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
-from src.dataset import create_dataset
+from src.config import config_ascend, config_gpu, config_cpu
+from src.dataset import create_dataset_imagenet, create_dataset_cifar10
 from src.inceptionv4 import Inceptionv4
-from src.config import config
+
+CFG_DICT = {
+    "Ascend": config_ascend,
+    "GPU": config_gpu,
+    "CPU": config_cpu,
+}
+
+DS_DICT = {
+    "imagenet": create_dataset_imagenet,
+    "cifar10": create_dataset_cifar10,
+}
 
 
 def parse_args():
     '''parse_args'''
     parser = argparse.ArgumentParser(description='image classification evaluation')
-    parser.add_argument('--platform', type=str, default='Ascend', choices=('Ascend', 'GPU'), help='run platform')
+    parser.add_argument('--platform', type=str, default='Ascend', choices=('Ascend', 'GPU', 'CPU'), help='run platform')
     parser.add_argument('--dataset_path', type=str, default='', help='Dataset path')
     parser.add_argument('--checkpoint_path', type=str, default='', help='checkpoint of inceptionV4')
     args_opt = parser.parse_args()
     return args_opt
 
 
 if __name__ == '__main__':
     args = parse_args()
@@ -42,18 +54,22 @@ if __name__ == '__main__':
         device_id = int(os.getenv('DEVICE_ID', '0'))
         context.set_context(device_id=device_id)
 
+    config = CFG_DICT[args.platform]
+    create_dataset = DS_DICT[config.ds_type]
+
     context.set_context(mode=context.GRAPH_MODE, device_target=args.platform)
     net = Inceptionv4(classes=config.num_classes)
     ckpt = load_checkpoint(args.checkpoint_path)
     load_param_into_net(net, ckpt)
     net.set_train(False)
-    dataset = create_dataset(dataset_path=args.dataset_path, do_train=False,
-                             repeat_num=1, batch_size=config.batch_size)
+    config.rank = 0
+    config.group_size = 1
+    dataset = create_dataset(dataset_path=args.dataset_path, do_train=False, cfg=config)
     loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
     eval_metrics = {'Loss': nn.Loss(),
                     'Top1-Acc': nn.Top1CategoricalAccuracy(),
                     'Top5-Acc': nn.Top5CategoricalAccuracy()}
     model = Model(net, loss, optimizer=None, metrics=eval_metrics)
     print('=' * 20, 'Evalute start', '=' * 20)
-    metrics = model.eval(dataset)
+    metrics = model.eval(dataset, dataset_sink_mode=config.ds_sink_mode)
     print("metric: ", metrics)
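eval.py now selects its config and dataset builder from the two lookup tables, and pins config.rank / config.group_size before building the dataset so the refactored helpers take their single-shard branch. A minimal sketch of that decision (helper and class names here are illustrative, not from the commit):

```python
# Mirrors the sharding decision inside create_dataset_imagenet /
# create_dataset_cifar10: only multi-process runs pass shard arguments.
def pick_shard_args(cfg):
    if cfg.group_size == 1:
        return {}  # whole dataset, no sharding
    return {'num_shards': cfg.group_size, 'shard_id': cfg.rank}

class Cfg:  # illustrative stand-in for the edict config
    rank, group_size = 0, 1

assert pick_shard_args(Cfg()) == {}  # evaluation sees every sample
```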
scripts/run_eval_cpu.sh (new file):

@@ -0,0 +1,28 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+rm -rf evaluation
+mkdir evaluation
+cp ./*.py ./evaluation
+cp -r ./src ./evaluation
+cd ./evaluation || exit
+
+DATA_DIR=$1
+CKPT_DIR=$2
+
+echo "start evaluation"
+
+python eval.py --dataset_path=$DATA_DIR --checkpoint_path=$CKPT_DIR --platform='CPU' > eval.log 2>&1 &
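This new script stages the Python sources into a clean evaluation/ directory and runs eval.py there in the background. Assuming it is invoked from the model root, a typical call is `sh scripts/run_eval_cpu.sh /path/to/cifar10 /path/to/inceptionv4.ckpt` (both paths are placeholders); metrics land in evaluation/eval.log.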
scripts/run_standalone_train_cpu.sh (new file):

@@ -0,0 +1,25 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+DATA_DIR=$1
+
+rm -rf train_standalone
+mkdir ./train_standalone
+cd ./train_standalone || exit
+env > env.log
+python -u ../train.py \
+--dataset_path=$DATA_DIR --platform=CPU > log.txt 2>&1 &
+cd ../
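Its training counterpart works the same way: launched as `sh scripts/run_standalone_train_cpu.sh /path/to/cifar10` (placeholder path), it runs train.py from a fresh train_standalone/ working directory, detaches, and logs progress to train_standalone/log.txt, with an environment snapshot in env.log.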
src/config.py:

@@ -17,13 +17,15 @@ network config setting, will be used in main.py
 """
 from easydict import EasyDict as edict
 
-config = edict({
+config_ascend = edict({
     'is_save_on_master': False,
 
     'batch_size': 128,
     'epoch_size': 250,
     'num_classes': 1000,
     'work_nums': 8,
+    'ds_type': 'imagenet',
+    'ds_sink_mode': True,
 
     'loss_scale': 1024,
     'smooth_factor': 0.1,
@@ -42,3 +44,57 @@ config = edict({
     'warmup_epochs': 1,
     'start_epoch': 1,
 })
+
+config_gpu = edict({
+    'is_save_on_master': False,
+
+    'batch_size': 128,
+    'epoch_size': 250,
+    'num_classes': 1000,
+    'work_nums': 8,
+    'ds_type': 'imagenet',
+    'ds_sink_mode': True,
+
+    'loss_scale': 1024,
+    'smooth_factor': 0.1,
+    'weight_decay': 0.00004,
+    'momentum': 0.9,
+    'amp_level': 'O0',
+    'decay': 0.9,
+    'epsilon': 1.0,
+
+    'keep_checkpoint_max': 10,
+    'save_checkpoint_epochs': 10,
+
+    'lr_init': 0.00004,
+    'lr_end': 0.000004,
+    'lr_max': 0.4,
+    'warmup_epochs': 1,
+    'start_epoch': 1,
+})
+
+config_cpu = edict({
+    'batch_size': 128,
+    'epoch_size': 250,
+    'num_classes': 10,
+    'work_nums': 8,
+    'ds_type': 'cifar10',
+    'ds_sink_mode': False,
+
+    'loss_scale': 1024,
+    'smooth_factor': 0.1,
+    'weight_decay': 0.00004,
+    'momentum': 0.9,
+    'amp_level': 'O0',
+    'decay': 0.9,
+    'epsilon': 1.0,
+
+    'keep_checkpoint_max': 10,
+    'save_checkpoint_epochs': 10,
+
+    'lr_init': 0.00004,
+    'lr_end': 0.000004,
+    'lr_max': 0.4,
+    'warmup_epochs': 1,
+    'start_epoch': 1,
+})
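The single config edict becomes three per-platform ones: config_gpu carries the former ImageNet settings forward, while config_cpu targets CIFAR-10 (ds_type 'cifar10', num_classes 10), disables dataset sink mode (data sinking is a device-queue feature that CPU execution does not use), and drops is_save_on_master, which only matters in multi-device runs. rank and group_size are deliberately absent; the entry scripts attach them at runtime, which EasyDict allows. A minimal sketch (values abridged from config_cpu):

```python
from easydict import EasyDict as edict

# Abridged stand-in for config_cpu: keys read as attributes, and new
# attributes (rank, group_size) can be attached later by train.py / eval.py.
cfg = edict({'batch_size': 128, 'num_classes': 10, 'ds_sink_mode': False})
cfg.rank, cfg.group_size = 0, 1
assert cfg.batch_size == 128 and cfg.group_size == 1
```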
src/dataset.py:

@@ -12,68 +12,100 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
-"""Create train or eval dataset."""
+"""
+Data operations, will be used in train.py and eval.py
+"""
 import os
+
 import mindspore.common.dtype as mstype
-import mindspore.dataset as de
-import mindspore.dataset.vision.c_transforms as C
+import mindspore.dataset as ds
 import mindspore.dataset.transforms.c_transforms as C2
-from src.config import config
-
-
-device_id = int(os.getenv('DEVICE_ID', '0'))
-device_num = int(os.getenv('RANK_SIZE', '1'))
+import mindspore.dataset.vision.c_transforms as C
 
 
-def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, shard_id=0):
+def create_dataset_imagenet(dataset_path, do_train, cfg, repeat_num=1):
     """
-    Create a train or eval dataset.
+    create a train or eval dataset
 
     Args:
-        dataset_path (str): The path of dataset.
-        do_train (bool): Whether dataset is used for train or eval.
-        repeat_num (int): The repeat times of dataset. Default: 1.
-        batch_size (int): The batch size of dataset. Default: 32.
+        dataset_path(string): the path of dataset.
+        do_train(bool): whether dataset is used for train or eval.
+        cfg (dict): the config for creating dataset.
+        repeat_num(int): the repeat times of dataset. Default: 1.
 
     Returns:
-        Dataset.
+        dataset
     """
-    do_shuffle = bool(do_train)
-
-    if device_num == 1 or not do_train:
-        ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=config.work_nums, shuffle=do_shuffle)
+    if cfg.group_size == 1:
+        data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=cfg.work_nums, shuffle=do_train)
     else:
-        ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=config.work_nums,
-                                   shuffle=do_shuffle, num_shards=device_num, shard_id=shard_id)
+        data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=cfg.work_nums, shuffle=do_train,
+                                         num_shards=cfg.group_size, shard_id=cfg.rank)
 
-    image_length = 299
+    # define map operations
+    size = 299
     if do_train:
         trans = [
-            C.RandomCropDecodeResize(image_length, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
+            C.RandomCropDecodeResize(size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
             C.RandomHorizontalFlip(prob=0.5),
             C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4)
         ]
     else:
         trans = [
             C.Decode(),
-            C.Resize(image_length),
-            C.CenterCrop(image_length)
+            C.Resize(size),
+            C.CenterCrop(size)
         ]
     trans += [
         C.Rescale(1.0 / 255.0, 0.0),
         C.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
         C.HWC2CHW()
     ]
+    type_cast_op = C2.TypeCast(mstype.int32)
+    data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=cfg.work_nums)
+    data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=cfg.work_nums)
+    # apply batch operations
+    data_set = data_set.batch(cfg.batch_size, drop_remainder=True)
+    # apply dataset repeat operation
+    data_set = data_set.repeat(repeat_num)
+    return data_set
+
+
+def create_dataset_cifar10(dataset_path, do_train, cfg, repeat_num=1):
+    """
+    create a train or eval dataset
+
+    Args:
+        dataset_path(string): the path of dataset.
+        do_train(bool): whether dataset is used for train or eval.
+        cfg (dict): the config for creating dataset.
+        repeat_num(int): the repeat times of dataset. Default: 1.
+
+    Returns:
+        dataset
+    """
+    dataset_path = os.path.join(dataset_path, "cifar-10-batches-bin" if do_train else "cifar-10-verify-bin")
+    if cfg.group_size == 1:
+        data_set = ds.Cifar10Dataset(dataset_path, num_parallel_workers=cfg.work_nums, shuffle=do_train)
+    else:
+        data_set = ds.Cifar10Dataset(dataset_path, num_parallel_workers=cfg.work_nums, shuffle=do_train,
+                                     num_shards=cfg.group_size, shard_id=cfg.rank)
+
+    # define map operations
+    trans = []
+    if do_train:
+        trans.append(C.RandomCrop((32, 32), (4, 4, 4, 4)))
+        trans.append(C.RandomHorizontalFlip(prob=0.5))
+
+    trans.append(C.Resize((299, 299)))
+    trans.append(C.Rescale(1.0 / 255.0, 0.0))
+    trans.append(C.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]))
+    trans.append(C.HWC2CHW())
+
     type_cast_op = C2.TypeCast(mstype.int32)
-    ds = ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=config.work_nums)
-    ds = ds.map(input_columns="image", operations=trans, num_parallel_workers=config.work_nums)
+    data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=cfg.work_nums)
+    data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=cfg.work_nums)
+
     # apply batch operations
-    ds = ds.batch(batch_size, drop_remainder=True)
+    data_set = data_set.batch(cfg.batch_size, drop_remainder=do_train)
     # apply dataset repeat operation
-    ds = ds.repeat(repeat_num)
-    return ds
+    data_set = data_set.repeat(repeat_num)
+    return data_set
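One detail worth noting in the new CIFAR-10 helper: it batches with drop_remainder=do_train, unlike the ImageNet helper's fixed drop_remainder=True. A small sketch of why, assuming the standard 10,000-image cifar-10-verify-bin split:

```python
# batch_size comes from config_cpu; 10,000 is the standard CIFAR-10 test split.
total, batch = 10_000, 128
full, leftover = divmod(total, batch)
print(full, leftover)  # 78 full batches, 16 leftover images
# do_train=True  -> drop_remainder=True:  78 equal-sized batches, stable shapes
# do_train=False -> drop_remainder=False: 79 batches, so all 10,000 images are scored
```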
train.py:

@@ -13,32 +13,63 @@
 # limitations under the License.
 # ============================================================================
 """train imagenet"""
-import os
 import argparse
 import math
+import os
+
 import numpy as np
 
-from mindspore.communication import init, get_rank
-from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor
-from mindspore.train.model import ParallelMode
-from mindspore.train.loss_scale_manager import FixedLossScaleManager
 from mindspore import Model
-from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
-from mindspore.nn import RMSProp
 from mindspore import Tensor
 from mindspore import context
 from mindspore.common import set_seed
 from mindspore.common.initializer import XavierUniform, initializer
+from mindspore.communication import init, get_rank, get_group_size
+from mindspore.nn import RMSProp
+from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
+from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor
+from mindspore.train.loss_scale_manager import FixedLossScaleManager
+from mindspore.train.model import ParallelMode
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
+from src.config import config_ascend, config_gpu, config_cpu
+from src.dataset import create_dataset_imagenet, create_dataset_cifar10
 from src.inceptionv4 import Inceptionv4
-from src.dataset import create_dataset, device_num
-from src.config import config
 
 os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'
 set_seed(1)
 
+CFG_DICT = {
+    "Ascend": config_ascend,
+    "GPU": config_gpu,
+    "CPU": config_cpu,
+}
+
+DS_DICT = {
+    "imagenet": create_dataset_imagenet,
+    "cifar10": create_dataset_cifar10,
+}
+
+device_num = int(os.getenv('RANK_SIZE', '1'))
+
+
+def parse_args():
+    '''parse_args'''
+    arg_parser = argparse.ArgumentParser(description='InceptionV4 image classification training')
+    arg_parser.add_argument('--dataset_path', type=str, default='', help='Dataset path')
+    arg_parser.add_argument('--device_id', type=int, default=0, help='device id')
+    arg_parser.add_argument('--platform', type=str, default='Ascend', choices=("Ascend", "GPU", "CPU"),
+                            help='Platform, support Ascend, GPU, CPU.')
+    arg_parser.add_argument('--resume', type=str, default='', help='resume training with existed checkpoint')
+    args_opt = arg_parser.parse_args()
+    return args_opt
+
+
+args = parse_args()
+
+config = CFG_DICT[args.platform]
+create_dataset = DS_DICT[config.ds_type]
+
 
 def generate_cosine_lr(steps_per_epoch, total_epochs,
                        lr_init=config.lr_init,
                        lr_end=config.lr_end,
@@ -87,7 +118,6 @@ def inception_v4_train():
     context.set_context(device_id=args.device_id)
     context.set_context(enable_graph_kernel=False)
 
-    rank = 0
     if device_num > 1:
         if args.platform == "Ascend":
             init(backend_name='hccl')
@@ -96,15 +126,18 @@ def inception_v4_train():
         else:
             raise ValueError("Unsupported device target.")
 
-        rank = get_rank()
+        config.rank = get_rank()
+        config.group_size = get_group_size()
         context.set_auto_parallel_context(device_num=device_num,
                                           parallel_mode=ParallelMode.DATA_PARALLEL,
                                           gradients_mean=True,
                                           all_reduce_fusion_config=[200, 400])
+    else:
+        config.rank = 0
+        config.group_size = 1
 
     # create dataset
-    train_dataset = create_dataset(dataset_path=args.dataset_path, do_train=True,
-                                   repeat_num=1, batch_size=config.batch_size, shard_id=rank)
+    train_dataset = create_dataset(dataset_path=args.dataset_path, do_train=True, cfg=config)
     train_step_size = train_dataset.get_dataset_size()
 
     # create model
@@ -140,23 +173,16 @@ def inception_v4_train():
 
     loss_scale_manager = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
 
-    if args.platform == "Ascend":
-        model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc', 'top_1_accuracy', 'top_5_accuracy'},
-                      loss_scale_manager=loss_scale_manager, amp_level=config.amp_level)
-    elif args.platform == "GPU":
-        model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc', 'top_1_accuracy', 'top_5_accuracy'},
-                      loss_scale_manager=loss_scale_manager, amp_level='O0')
-    else:
-        raise ValueError("Unsupported device target.")
+    model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc', 'top_1_accuracy', 'top_5_accuracy'},
+                  loss_scale_manager=loss_scale_manager, amp_level=config.amp_level)
 
     # define callbacks
     performance_cb = TimeMonitor(data_size=train_step_size)
     loss_cb = LossMonitor(per_print_times=train_step_size)
     ckp_save_step = config.save_checkpoint_epochs * train_step_size
     config_ck = CheckpointConfig(save_checkpoint_steps=ckp_save_step, keep_checkpoint_max=config.keep_checkpoint_max)
-    ckpoint_cb = ModelCheckpoint(prefix=f"inceptionV4-train-rank{rank}",
-                                 directory='ckpts_rank_' + str(rank), config=config_ck)
+    ckpoint_cb = ModelCheckpoint(prefix=f"inceptionV4-train-rank{config.rank}",
+                                 directory='ckpts_rank_' + str(config.rank), config=config_ck)
     callbacks = [performance_cb, loss_cb]
     if device_num > 1 and config.is_save_on_master:
         if args.device_id == 0:
@@ -165,21 +191,9 @@ def inception_v4_train():
         callbacks.append(ckpoint_cb)
 
     # train model
-    model.train(config.epoch_size, train_dataset, callbacks=callbacks, dataset_sink_mode=True)
+    model.train(config.epoch_size, train_dataset, callbacks=callbacks, dataset_sink_mode=config.ds_sink_mode)
 
 
-def parse_args():
-    '''parse_args'''
-    arg_parser = argparse.ArgumentParser(description='InceptionV4 image classification training')
-    arg_parser.add_argument('--dataset_path', type=str, default='', help='Dataset path')
-    arg_parser.add_argument('--device_id', type=int, default=0, help='device id')
-    arg_parser.add_argument('--platform', type=str, default='Ascend', choices=("Ascend", "GPU"),
-                            help='Platform, support Ascend, GPU.')
-    arg_parser.add_argument('--resume', type=str, default='', help='resume training with existed checkpoint')
-    args_opt = arg_parser.parse_args()
-    return args_opt
-
-
 if __name__ == '__main__':
-    args = parse_args()
     inception_v4_train()
     print('Inceptionv4 training success!')
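With the per-platform branching folded into the configs, one Model construction and one model.train call now serve Ascend, GPU, and CPU alike. As a worked example of the resulting checkpoint cadence, assuming the CPU config and CIFAR-10's standard 50,000-image training split:

```python
# Checkpoint cadence under config_cpu, assuming the standard CIFAR-10
# training split of 50,000 images (cifar-10-batches-bin).
train_images, batch_size = 50_000, 128
train_step_size = train_images // batch_size  # 390 steps/epoch (drop_remainder=True at train time)
ckp_save_step = 10 * train_step_size          # save_checkpoint_epochs = 10 in config_cpu
print(train_step_size, ckp_save_step)         # 390 3900
```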