!15913 inceptionv4 support cpu training

From: @caojian05
Reviewed-by: @wuxuejian
Signed-off-by: @wuxuejian
This commit is contained in:
mindspore-ci-bot 2021-05-12 15:57:32 +08:00 committed by Gitee
commit 23fc8506d2
9 changed files with 275 additions and 88 deletions

View File

@ -35,7 +35,9 @@ void BatchNormCPUKernel::InitKernel(const CNodePtr &kernel_node) {
is_train = AnfAlgo::GetNodeAttr<bool>(kernel_node, "is_training"); is_train = AnfAlgo::GetNodeAttr<bool>(kernel_node, "is_training");
momentum = AnfAlgo::GetNodeAttr<float>(kernel_node, "momentum"); momentum = AnfAlgo::GetNodeAttr<float>(kernel_node, "momentum");
std::vector<size_t> x_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); std::vector<size_t> x_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
if (x_shape.size() != 4) { if (x_shape.size() == 2) {
x_shape.insert(x_shape.end(), 2, 1);
} else if (x_shape.size() != 4) {
MS_LOG(EXCEPTION) << "Batchnorm only support nchw input!"; MS_LOG(EXCEPTION) << "Batchnorm only support nchw input!";
} }
batch_size = x_shape[0]; batch_size = x_shape[0];

View File

@ -38,7 +38,9 @@ void BatchNormGradCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
void BatchNormGradCPUKernel::InitKernel(const CNodePtr &kernel_node) { void BatchNormGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
MS_EXCEPTION_IF_NULL(kernel_node); MS_EXCEPTION_IF_NULL(kernel_node);
std::vector<size_t> x_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0); std::vector<size_t> x_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
if (x_shape.size() != 4) { if (x_shape.size() == 2) {
x_shape.insert(x_shape.end(), 2, 1);
} else if (x_shape.size() != 4) {
MS_LOG(EXCEPTION) << "Fused batchnorm only support nchw input!"; MS_LOG(EXCEPTION) << "Fused batchnorm only support nchw input!";
} }
batch_size = x_shape[0]; batch_size = x_shape[0];

View File

@ -71,6 +71,8 @@ For FP16 operators, if the input data type is FP32, the backend of MindSpore wil
├─scripts ├─scripts
├─run_distribute_train_gpu.sh # launch distributed training with gpu platform(8p) ├─run_distribute_train_gpu.sh # launch distributed training with gpu platform(8p)
├─run_eval_gpu.sh # launch evaluating with gpu platform ├─run_eval_gpu.sh # launch evaluating with gpu platform
├─run_eval_cpu.sh # launch evaluating with cpu platform
├─run_standalone_train_cpu.sh # launch standalone training with cpu platform(1p)
├─run_standalone_train_ascend.sh # launch standalone training with ascend platform(1p) ├─run_standalone_train_ascend.sh # launch standalone training with ascend platform(1p)
├─run_distribute_train_ascend.sh # launch distributed training with ascend platform(8p) ├─run_distribute_train_ascend.sh # launch distributed training with ascend platform(8p)
├─run_infer_310.sh # shell script for 310 inference ├─run_infer_310.sh # shell script for 310 inference
@ -138,6 +140,13 @@ sh scripts/run_standalone_train_ascend.sh DEVICE_ID DATA_DIR
sh scripts/run_distribute_train_gpu.sh DATA_PATH sh scripts/run_distribute_train_gpu.sh DATA_PATH
``` ```
- CPU:
```bash
# standalone training example with shell
sh scripts/run_standalone_train_cpu.sh DATA_PATH
```
### Launch ### Launch
```bash ```bash
@ -151,6 +160,9 @@ sh scripts/run_distribute_train_gpu.sh DATA_PATH
GPU: GPU:
# distribute training example(8p) # distribute training example(8p)
sh scripts/run_distribute_train_gpu.sh DATA_PATH sh scripts/run_distribute_train_gpu.sh DATA_PATH
CPU:
# standalone training example with shell
sh scripts/run_standalone_train_cpu.sh DATA_PATH
``` ```
### Result ### Result

View File

@ -18,23 +18,35 @@ import os
import mindspore.nn as nn import mindspore.nn as nn
from mindspore import context from mindspore import context
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.train.model import Model from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits from src.config import config_ascend, config_gpu, config_cpu
from src.dataset import create_dataset_imagenet, create_dataset_cifar10
from src.dataset import create_dataset
from src.inceptionv4 import Inceptionv4 from src.inceptionv4 import Inceptionv4
from src.config import config
# Hyper-parameter set for each supported --platform value.
CFG_DICT = dict(Ascend=config_ascend, GPU=config_gpu, CPU=config_cpu)
# Dataset factory for each value of the config's 'ds_type' entry.
DS_DICT = dict(imagenet=create_dataset_imagenet, cifar10=create_dataset_cifar10)
def parse_args(): def parse_args():
'''parse_args''' '''parse_args'''
parser = argparse.ArgumentParser(description='image classification evaluation') parser = argparse.ArgumentParser(description='image classification evaluation')
parser.add_argument('--platform', type=str, default='Ascend', choices=('Ascend', 'GPU'), help='run platform') parser.add_argument('--platform', type=str, default='Ascend', choices=('Ascend', 'GPU', 'CPU'), help='run platform')
parser.add_argument('--dataset_path', type=str, default='', help='Dataset path') parser.add_argument('--dataset_path', type=str, default='', help='Dataset path')
parser.add_argument('--checkpoint_path', type=str, default='', help='checkpoint of inceptionV4') parser.add_argument('--checkpoint_path', type=str, default='', help='checkpoint of inceptionV4')
args_opt = parser.parse_args() args_opt = parser.parse_args()
return args_opt return args_opt
if __name__ == '__main__': if __name__ == '__main__':
args = parse_args() args = parse_args()
@ -42,18 +54,22 @@ if __name__ == '__main__':
device_id = int(os.getenv('DEVICE_ID', '0')) device_id = int(os.getenv('DEVICE_ID', '0'))
context.set_context(device_id=device_id) context.set_context(device_id=device_id)
config = CFG_DICT[args.platform]
create_dataset = DS_DICT[config.ds_type]
context.set_context(mode=context.GRAPH_MODE, device_target=args.platform) context.set_context(mode=context.GRAPH_MODE, device_target=args.platform)
net = Inceptionv4(classes=config.num_classes) net = Inceptionv4(classes=config.num_classes)
ckpt = load_checkpoint(args.checkpoint_path) ckpt = load_checkpoint(args.checkpoint_path)
load_param_into_net(net, ckpt) load_param_into_net(net, ckpt)
net.set_train(False) net.set_train(False)
dataset = create_dataset(dataset_path=args.dataset_path, do_train=False, config.rank = 0
repeat_num=1, batch_size=config.batch_size) config.group_size = 1
dataset = create_dataset(dataset_path=args.dataset_path, do_train=False, cfg=config)
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean") loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
eval_metrics = {'Loss': nn.Loss(), eval_metrics = {'Loss': nn.Loss(),
'Top1-Acc': nn.Top1CategoricalAccuracy(), 'Top1-Acc': nn.Top1CategoricalAccuracy(),
'Top5-Acc': nn.Top5CategoricalAccuracy()} 'Top5-Acc': nn.Top5CategoricalAccuracy()}
model = Model(net, loss, optimizer=None, metrics=eval_metrics) model = Model(net, loss, optimizer=None, metrics=eval_metrics)
print('=' * 20, 'Evalute start', '=' * 20) print('=' * 20, 'Evalute start', '=' * 20)
metrics = model.eval(dataset) metrics = model.eval(dataset, dataset_sink_mode=config.ds_sink_mode)
print("metric: ", metrics) print("metric: ", metrics)

View File

@ -0,0 +1,28 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Evaluate InceptionV4 on CPU inside a scratch copy of the sources.
# Usage: sh run_eval_cpu.sh DATA_DIR CKPT_PATH
# The job is started in the background; its output goes to evaluation/eval.log.
if [ $# -lt 2 ]; then
  echo "Usage: sh run_eval_cpu.sh DATA_DIR CKPT_PATH"
  exit 1
fi

# Print $1 as an absolute path. Relative paths are anchored at the current
# directory. Only POSIX sh constructs are used, because the README launches
# these scripts with `sh`.
get_abs_path() {
  case "$1" in
    /*) printf '%s\n' "$1" ;;
    *) printf '%s\n' "$PWD/$1" ;;
  esac
}

# Resolve both arguments BEFORE changing directory: eval.py runs from
# ./evaluation, so a relative path would otherwise point to the wrong place.
DATA_DIR=$(get_abs_path "$1")
CKPT_DIR=$(get_abs_path "$2")

rm -rf evaluation
mkdir evaluation
cp ./*.py ./evaluation
cp -r ./src ./evaluation
cd ./evaluation || exit 1
echo "start evaluation"
# Quote the paths so directories containing spaces are passed as one argument.
python eval.py --dataset_path="$DATA_DIR" --checkpoint_path="$CKPT_DIR" --platform='CPU' > eval.log 2>&1 &

View File

@ -0,0 +1,25 @@
#!/bin/bash
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Launch standalone InceptionV4 training on CPU.
# Usage: sh run_standalone_train_cpu.sh DATA_DIR
# The job is started in the background from ./train_standalone; its output
# goes to train_standalone/log.txt.
if [ $# -lt 1 ]; then
  echo "Usage: sh run_standalone_train_cpu.sh DATA_DIR"
  exit 1
fi

# Make the dataset path absolute BEFORE changing directory: train.py is started
# from ./train_standalone, so a relative path would otherwise be misresolved.
# `case` keeps this POSIX-sh compatible (the README launches the script with `sh`).
case "$1" in
  /*) DATA_DIR=$1 ;;
  *) DATA_DIR=$PWD/$1 ;;
esac

rm -rf train_standalone
mkdir ./train_standalone
cd ./train_standalone || exit 1
# Snapshot the environment next to the training log.
env > env.log
# Quote the path so a directory containing spaces is passed as one argument.
python -u ../train.py \
  --dataset_path="$DATA_DIR" --platform=CPU > log.txt 2>&1 &
cd ../

View File

@ -17,13 +17,15 @@ network config setting, will be used in main.py
""" """
from easydict import EasyDict as edict from easydict import EasyDict as edict
config = edict({ config_ascend = edict({
'is_save_on_master': False, 'is_save_on_master': False,
'batch_size': 128, 'batch_size': 128,
'epoch_size': 250, 'epoch_size': 250,
'num_classes': 1000, 'num_classes': 1000,
'work_nums': 8, 'work_nums': 8,
'ds_type': 'imagenet',
'ds_sink_mode': True,
'loss_scale': 1024, 'loss_scale': 1024,
'smooth_factor': 0.1, 'smooth_factor': 0.1,
@ -42,3 +44,57 @@ config = edict({
'warmup_epochs': 1, 'warmup_epochs': 1,
'start_epoch': 1, 'start_epoch': 1,
}) })
# Settings selected when the scripts run with --platform=GPU.
config_gpu = edict({
    # run / dataset setup
    'is_save_on_master': False,
    'batch_size': 128,
    'epoch_size': 250,
    'num_classes': 1000,
    'work_nums': 8,           # passed as num_parallel_workers to the dataset ops
    'ds_type': 'imagenet',    # key into DS_DICT in train.py / eval.py
    'ds_sink_mode': True,     # forwarded as dataset_sink_mode

    # loss
    'loss_scale': 1024,
    'smooth_factor': 0.1,

    # optimizer-related values
    'weight_decay': 0.00004,
    'momentum': 0.9,
    'amp_level': 'O0',
    'decay': 0.9,
    'epsilon': 1.0,

    # checkpointing
    'keep_checkpoint_max': 10,
    'save_checkpoint_epochs': 10,

    # learning-rate schedule
    'lr_init': 0.00004,
    'lr_end': 0.000004,
    'lr_max': 0.4,
    'warmup_epochs': 1,
    'start_epoch': 1,
})
# Settings selected when the scripts run with --platform=CPU (CIFAR-10, 10 classes).
config_cpu = edict({
    # run / dataset setup
    # Present in the Ascend and GPU configs and read by train.py as
    # config.is_save_on_master; defined here too so every platform config
    # exposes the same keys.
    'is_save_on_master': False,
    'batch_size': 128,
    'epoch_size': 250,
    'num_classes': 10,
    'work_nums': 8,           # passed as num_parallel_workers to the dataset ops
    'ds_type': 'cifar10',     # key into DS_DICT in train.py / eval.py
    'ds_sink_mode': False,    # forwarded as dataset_sink_mode

    # loss
    'loss_scale': 1024,
    'smooth_factor': 0.1,

    # optimizer-related values
    'weight_decay': 0.00004,
    'momentum': 0.9,
    'amp_level': 'O0',
    'decay': 0.9,
    'epsilon': 1.0,

    # checkpointing
    'keep_checkpoint_max': 10,
    'save_checkpoint_epochs': 10,

    # learning-rate schedule
    'lr_init': 0.00004,
    'lr_end': 0.000004,
    'lr_max': 0.4,
    'warmup_epochs': 1,
    'start_epoch': 1,
})

View File

@ -12,68 +12,100 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""Create train or eval dataset."""
import os
import mindspore.common.dtype as mstype
import mindspore.dataset as de
import mindspore.dataset.vision.c_transforms as C
import mindspore.dataset.transforms.c_transforms as C2
from src.config import config
device_id = int(os.getenv('DEVICE_ID', '0'))
device_num = int(os.getenv('RANK_SIZE', '1'))
def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, shard_id=0):
""" """
Create a train or eval dataset. Data operations, will be used in train.py and eval.py
"""
import os
import mindspore.common.dtype as mstype
import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as C2
import mindspore.dataset.vision.c_transforms as C
def create_dataset_imagenet(dataset_path, do_train, cfg, repeat_num=1):
"""
create a train or eval dataset
Args: Args:
dataset_path (str): The path of dataset. dataset_path(string): the path of dataset.
do_train (bool): Whether dataset is used for train or eval. do_train(bool): whether dataset is used for train or eval.
repeat_num (int): The repeat times of dataset. Default: 1. cfg (dict): the config for creating dataset.
batch_size (int): The batch size of dataset. Default: 32. repeat_num(int): the repeat times of dataset. Default: 1.
Returns: Returns:
Dataset. dataset
""" """
if cfg.group_size == 1:
do_shuffle = bool(do_train) data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=cfg.work_nums, shuffle=do_train)
if device_num == 1 or not do_train:
ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=config.work_nums, shuffle=do_shuffle)
else: else:
ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=config.work_nums, data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=cfg.work_nums, shuffle=do_train,
shuffle=do_shuffle, num_shards=device_num, shard_id=shard_id) num_shards=cfg.group_size, shard_id=cfg.rank)
# define map operations
image_length = 299 size = 299
if do_train: if do_train:
trans = [ trans = [
C.RandomCropDecodeResize(image_length, scale=(0.08, 1.0), ratio=(0.75, 1.333)), C.RandomCropDecodeResize(size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
C.RandomHorizontalFlip(prob=0.5), C.RandomHorizontalFlip(prob=0.5),
C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4) C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4)
] ]
else: else:
trans = [ trans = [
C.Decode(), C.Decode(),
C.Resize(image_length), C.Resize(size),
C.CenterCrop(image_length) C.CenterCrop(size)
] ]
trans += [ trans += [
C.Rescale(1.0 / 255.0, 0.0), C.Rescale(1.0 / 255.0, 0.0),
C.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), C.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
C.HWC2CHW() C.HWC2CHW()
] ]
type_cast_op = C2.TypeCast(mstype.int32)
data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=cfg.work_nums)
data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=cfg.work_nums)
# apply batch operations
data_set = data_set.batch(cfg.batch_size, drop_remainder=True)
# apply dataset repeat operation
data_set = data_set.repeat(repeat_num)
return data_set
def create_dataset_cifar10(dataset_path, do_train, cfg, repeat_num=1):
"""
create a train or eval dataset
Args:
dataset_path(string): the path of dataset.
do_train(bool): whether dataset is used for train or eval.
cfg (dict): the config for creating dataset.
repeat_num(int): the repeat times of dataset. Default: 1.
Returns:
dataset
"""
dataset_path = os.path.join(dataset_path, "cifar-10-batches-bin" if do_train else "cifar-10-verify-bin")
if cfg.group_size == 1:
data_set = ds.Cifar10Dataset(dataset_path, num_parallel_workers=cfg.work_nums, shuffle=do_train)
else:
data_set = ds.Cifar10Dataset(dataset_path, num_parallel_workers=cfg.work_nums, shuffle=do_train,
num_shards=cfg.group_size, shard_id=cfg.rank)
# define map operations
trans = []
if do_train:
trans.append(C.RandomCrop((32, 32), (4, 4, 4, 4)))
trans.append(C.RandomHorizontalFlip(prob=0.5))
trans.append(C.Resize((299, 299)))
trans.append(C.Rescale(1.0 / 255.0, 0.0))
trans.append(C.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]))
trans.append(C.HWC2CHW())
type_cast_op = C2.TypeCast(mstype.int32) type_cast_op = C2.TypeCast(mstype.int32)
data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=cfg.work_nums)
ds = ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=config.work_nums) data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=cfg.work_nums)
ds = ds.map(input_columns="image", operations=trans, num_parallel_workers=config.work_nums)
# apply batch operations # apply batch operations
ds = ds.batch(batch_size, drop_remainder=True) data_set = data_set.batch(cfg.batch_size, drop_remainder=do_train)
# apply dataset repeat operation # apply dataset repeat operation
ds = ds.repeat(repeat_num) data_set = data_set.repeat(repeat_num)
return ds return data_set

View File

@ -13,32 +13,63 @@
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""train imagenet""" """train imagenet"""
import os
import argparse import argparse
import math import math
import os
import numpy as np import numpy as np
from mindspore.communication import init, get_rank
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor
from mindspore.train.model import ParallelMode
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore import Model from mindspore import Model
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.nn import RMSProp
from mindspore import Tensor from mindspore import Tensor
from mindspore import context from mindspore import context
from mindspore.common import set_seed from mindspore.common import set_seed
from mindspore.common.initializer import XavierUniform, initializer from mindspore.common.initializer import XavierUniform, initializer
from mindspore.communication import init, get_rank, get_group_size
from mindspore.nn import RMSProp
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor
from mindspore.train.loss_scale_manager import FixedLossScaleManager
from mindspore.train.model import ParallelMode
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.config import config_ascend, config_gpu, config_cpu
from src.dataset import create_dataset_imagenet, create_dataset_cifar10
from src.inceptionv4 import Inceptionv4 from src.inceptionv4 import Inceptionv4
from src.dataset import create_dataset, device_num
from src.config import config
os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python' os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'
set_seed(1) set_seed(1)
# Hyper-parameter set for each supported --platform value.
CFG_DICT = dict(Ascend=config_ascend, GPU=config_gpu, CPU=config_cpu)
# Dataset factory for each value of the config's 'ds_type' entry.
DS_DICT = dict(imagenet=create_dataset_imagenet, cifar10=create_dataset_cifar10)
# Device count taken from the RANK_SIZE environment variable (1 when unset).
device_num = int(os.environ.get('RANK_SIZE', '1'))
def parse_args():
    """Parse the command-line options of the InceptionV4 training script.

    Returns:
        argparse.Namespace with the attributes ``dataset_path``, ``device_id``,
        ``platform`` and ``resume``.
    """
    cli = argparse.ArgumentParser(description='InceptionV4 image classification training')
    cli.add_argument('--dataset_path', type=str, default='', help='Dataset path')
    cli.add_argument('--device_id', type=int, default=0, help='device id')
    cli.add_argument('--platform', type=str, default='Ascend', choices=("Ascend", "GPU", "CPU"),
                     help='Platform, support Ascend, GPU, CPU.')
    cli.add_argument('--resume', type=str, default='', help='resume training with existed checkpoint')
    return cli.parse_args()
# The command line is parsed at import time: the default arguments of
# generate_cosine_lr below read `config`, so the platform has to be known
# before that function is defined.
args = parse_args()
# Hyper-parameter set for the selected --platform (Ascend / GPU / CPU).
config = CFG_DICT[args.platform]
# Dataset factory matching the config's 'ds_type' entry ('imagenet' or 'cifar10').
create_dataset = DS_DICT[config.ds_type]
def generate_cosine_lr(steps_per_epoch, total_epochs, def generate_cosine_lr(steps_per_epoch, total_epochs,
lr_init=config.lr_init, lr_init=config.lr_init,
lr_end=config.lr_end, lr_end=config.lr_end,
@ -87,7 +118,6 @@ def inception_v4_train():
context.set_context(device_id=args.device_id) context.set_context(device_id=args.device_id)
context.set_context(enable_graph_kernel=False) context.set_context(enable_graph_kernel=False)
rank = 0
if device_num > 1: if device_num > 1:
if args.platform == "Ascend": if args.platform == "Ascend":
init(backend_name='hccl') init(backend_name='hccl')
@ -96,15 +126,18 @@ def inception_v4_train():
else: else:
raise ValueError("Unsupported device target.") raise ValueError("Unsupported device target.")
rank = get_rank() config.rank = get_rank()
config.group_size = get_group_size()
context.set_auto_parallel_context(device_num=device_num, context.set_auto_parallel_context(device_num=device_num,
parallel_mode=ParallelMode.DATA_PARALLEL, parallel_mode=ParallelMode.DATA_PARALLEL,
gradients_mean=True, gradients_mean=True,
all_reduce_fusion_config=[200, 400]) all_reduce_fusion_config=[200, 400])
else:
config.rank = 0
config.group_size = 1
# create dataset # create dataset
train_dataset = create_dataset(dataset_path=args.dataset_path, do_train=True, train_dataset = create_dataset(dataset_path=args.dataset_path, do_train=True, cfg=config)
repeat_num=1, batch_size=config.batch_size, shard_id=rank)
train_step_size = train_dataset.get_dataset_size() train_step_size = train_dataset.get_dataset_size()
# create model # create model
@ -140,23 +173,16 @@ def inception_v4_train():
loss_scale_manager = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False) loss_scale_manager = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
if args.platform == "Ascend":
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc', 'top_1_accuracy', 'top_5_accuracy'}, model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc', 'top_1_accuracy', 'top_5_accuracy'},
loss_scale_manager=loss_scale_manager, amp_level=config.amp_level) loss_scale_manager=loss_scale_manager, amp_level=config.amp_level)
elif args.platform == "GPU":
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc', 'top_1_accuracy', 'top_5_accuracy'},
loss_scale_manager=loss_scale_manager, amp_level='O0')
else:
raise ValueError("Unsupported device target.")
# define callbacks # define callbacks
performance_cb = TimeMonitor(data_size=train_step_size) performance_cb = TimeMonitor(data_size=train_step_size)
loss_cb = LossMonitor(per_print_times=train_step_size) loss_cb = LossMonitor(per_print_times=train_step_size)
ckp_save_step = config.save_checkpoint_epochs * train_step_size ckp_save_step = config.save_checkpoint_epochs * train_step_size
config_ck = CheckpointConfig(save_checkpoint_steps=ckp_save_step, keep_checkpoint_max=config.keep_checkpoint_max) config_ck = CheckpointConfig(save_checkpoint_steps=ckp_save_step, keep_checkpoint_max=config.keep_checkpoint_max)
ckpoint_cb = ModelCheckpoint(prefix=f"inceptionV4-train-rank{rank}", ckpoint_cb = ModelCheckpoint(prefix=f"inceptionV4-train-rank{config.rank}",
directory='ckpts_rank_' + str(rank), config=config_ck) directory='ckpts_rank_' + str(config.rank), config=config_ck)
callbacks = [performance_cb, loss_cb] callbacks = [performance_cb, loss_cb]
if device_num > 1 and config.is_save_on_master: if device_num > 1 and config.is_save_on_master:
if args.device_id == 0: if args.device_id == 0:
@ -165,21 +191,9 @@ def inception_v4_train():
callbacks.append(ckpoint_cb) callbacks.append(ckpoint_cb)
# train model # train model
model.train(config.epoch_size, train_dataset, callbacks=callbacks, dataset_sink_mode=True) model.train(config.epoch_size, train_dataset, callbacks=callbacks, dataset_sink_mode=config.ds_sink_mode)
def parse_args():
'''parse_args'''
arg_parser = argparse.ArgumentParser(description='InceptionV4 image classification training')
arg_parser.add_argument('--dataset_path', type=str, default='', help='Dataset path')
arg_parser.add_argument('--device_id', type=int, default=0, help='device id')
arg_parser.add_argument('--platform', type=str, default='Ascend', choices=("Ascend", "GPU"),
help='Platform, support Ascend, GPU.')
arg_parser.add_argument('--resume', type=str, default='', help='resume training with existed checkpoint')
args_opt = arg_parser.parse_args()
return args_opt
if __name__ == '__main__': if __name__ == '__main__':
args = parse_args()
inception_v4_train() inception_v4_train()
print('Inceptionv4 training success!') print('Inceptionv4 training success!')