forked from mindspore-Ecosystem/mindspore

!15913 inceptionv4 support cpu training

From: @caojian05 Reviewed-by: @wuxuejian Signed-off-by: @wuxuejian

Commit: 23fc8506d2
@@ -35,7 +35,9 @@ void BatchNormCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   is_train = AnfAlgo::GetNodeAttr<bool>(kernel_node, "is_training");
   momentum = AnfAlgo::GetNodeAttr<float>(kernel_node, "momentum");
   std::vector<size_t> x_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
-  if (x_shape.size() != 4) {
+  if (x_shape.size() == 2) {
+    x_shape.insert(x_shape.end(), 2, 1);
+  } else if (x_shape.size() != 4) {
     MS_LOG(EXCEPTION) << "Batchnorm only support nchw input!";
   }
   batch_size = x_shape[0];
@@ -38,7 +38,9 @@ void BatchNormGradCPUKernel::InitInputOutputSize(const CNodePtr &kernel_node) {
 void BatchNormGradCPUKernel::InitKernel(const CNodePtr &kernel_node) {
   MS_EXCEPTION_IF_NULL(kernel_node);
   std::vector<size_t> x_shape = AnfAlgo::GetInputDeviceShape(kernel_node, 0);
-  if (x_shape.size() != 4) {
+  if (x_shape.size() == 2) {
+    x_shape.insert(x_shape.end(), 2, 1);
+  } else if (x_shape.size() != 4) {
     MS_LOG(EXCEPTION) << "Fused batchnorm only support nchw input!";
   }
   batch_size = x_shape[0];
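Both kernel changes follow the same pattern: a 2-D (N, C) input is padded with two trailing singleton dimensions so the existing NCHW batch-norm path can be reused instead of rejecting the shape. A minimal NumPy sketch of why that padding is equivalent (illustration only, not MindSpore code):

```python
import numpy as np

def batchnorm_2d_via_nchw(x, eps=1e-5):
    """Normalize an (N, C) batch by viewing it as NCHW with H = W = 1."""
    n, c = x.shape
    x4 = x.reshape(n, c, 1, 1)                     # same trick as x_shape.insert(end, 2, 1)
    mean = x4.mean(axis=(0, 2, 3), keepdims=True)  # per-channel statistics
    var = x4.var(axis=(0, 2, 3), keepdims=True)
    return ((x4 - mean) / np.sqrt(var + eps)).reshape(n, c)

x = np.random.randn(8, 4).astype(np.float32)
ref = (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + 1e-5)  # plain per-column normalization
assert np.allclose(batchnorm_2d_via_nchw(x), ref, atol=1e-5)
```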
@@ -71,6 +71,8 @@ For FP16 operators, if the input data type is FP32, the backend of MindSpore will
   ├─scripts
     ├─run_distribute_train_gpu.sh        # launch distributed training with gpu platform(8p)
     ├─run_eval_gpu.sh                    # launch evaluating with gpu platform
+    ├─run_eval_cpu.sh                    # launch evaluating with cpu platform
+    ├─run_standalone_train_cpu.sh        # launch standalone training with cpu platform(1p)
     ├─run_standalone_train_ascend.sh     # launch standalone training with ascend platform(1p)
     ├─run_distribute_train_ascend.sh     # launch distributed training with ascend platform(8p)
     ├─run_infer_310.sh                   # shell script for 310 inference
@@ -138,6 +140,13 @@ sh scripts/run_standalone_train_ascend.sh DEVICE_ID DATA_DIR
 sh scripts/run_distribute_train_gpu.sh DATA_PATH
 ```

+- CPU:
+
+```bash
+# standalone training example with shell
+sh scripts/run_standalone_train_cpu.sh DATA_PATH
+```
+
 ### Launch

 ```bash
@@ -151,6 +160,9 @@ sh scripts/run_distribute_train_gpu.sh DATA_PATH
 GPU:
 # distribute training example(8p)
 sh scripts/run_distribute_train_gpu.sh DATA_PATH
+CPU:
+# standalone training example with shell
+sh scripts/run_standalone_train_cpu.sh DATA_PATH
 ```

 ### Result
@@ -18,23 +18,35 @@ import os

 import mindspore.nn as nn
 from mindspore import context
+from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
 from mindspore.train.model import Model
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
-from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits

-from src.dataset import create_dataset
+from src.config import config_ascend, config_gpu, config_cpu
+from src.dataset import create_dataset_imagenet, create_dataset_cifar10
 from src.inceptionv4 import Inceptionv4
-from src.config import config

+CFG_DICT = {
+    "Ascend": config_ascend,
+    "GPU": config_gpu,
+    "CPU": config_cpu,
+}
+
+DS_DICT = {
+    "imagenet": create_dataset_imagenet,
+    "cifar10": create_dataset_cifar10,
+}
+

 def parse_args():
     '''parse_args'''
     parser = argparse.ArgumentParser(description='image classification evaluation')
-    parser.add_argument('--platform', type=str, default='Ascend', choices=('Ascend', 'GPU'), help='run platform')
+    parser.add_argument('--platform', type=str, default='Ascend', choices=('Ascend', 'GPU', 'CPU'), help='run platform')
     parser.add_argument('--dataset_path', type=str, default='', help='Dataset path')
     parser.add_argument('--checkpoint_path', type=str, default='', help='checkpoint of inceptionV4')
     args_opt = parser.parse_args()
     return args_opt


 if __name__ == '__main__':
     args = parse_args()
@@ -42,18 +54,22 @@ if __name__ == '__main__':
         device_id = int(os.getenv('DEVICE_ID', '0'))
         context.set_context(device_id=device_id)

+    config = CFG_DICT[args.platform]
+    create_dataset = DS_DICT[config.ds_type]
+
     context.set_context(mode=context.GRAPH_MODE, device_target=args.platform)
     net = Inceptionv4(classes=config.num_classes)
     ckpt = load_checkpoint(args.checkpoint_path)
     load_param_into_net(net, ckpt)
     net.set_train(False)
-    dataset = create_dataset(dataset_path=args.dataset_path, do_train=False,
-                             repeat_num=1, batch_size=config.batch_size)
+    config.rank = 0
+    config.group_size = 1
+    dataset = create_dataset(dataset_path=args.dataset_path, do_train=False, cfg=config)
     loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean")
     eval_metrics = {'Loss': nn.Loss(),
                     'Top1-Acc': nn.Top1CategoricalAccuracy(),
                     'Top5-Acc': nn.Top5CategoricalAccuracy()}
     model = Model(net, loss, optimizer=None, metrics=eval_metrics)
     print('=' * 20, 'Evalute start', '=' * 20)
-    metrics = model.eval(dataset)
+    metrics = model.eval(dataset, dataset_sink_mode=config.ds_sink_mode)
     print("metric: ", metrics)
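The dispatch used here in eval.py (and again in train.py below) is a two-step table lookup: the --platform flag selects a config edict from CFG_DICT, and that config's ds_type field selects the dataset factory from DS_DICT. A stripped-down sketch of the same pattern follows; the toy factories are placeholders, not the real src.dataset functions:

```python
from easydict import EasyDict as edict

config_gpu = edict({'ds_type': 'imagenet', 'num_classes': 1000, 'batch_size': 128})
config_cpu = edict({'ds_type': 'cifar10', 'num_classes': 10, 'batch_size': 128})

def create_dataset_imagenet(dataset_path, do_train, cfg, repeat_num=1):
    return f"imagenet({dataset_path}, train={do_train}, bs={cfg.batch_size})"  # placeholder

def create_dataset_cifar10(dataset_path, do_train, cfg, repeat_num=1):
    return f"cifar10({dataset_path}, train={do_train}, bs={cfg.batch_size})"   # placeholder

CFG_DICT = {"GPU": config_gpu, "CPU": config_cpu}
DS_DICT = {"imagenet": create_dataset_imagenet, "cifar10": create_dataset_cifar10}

platform = "CPU"                            # would come from --platform
config = CFG_DICT[platform]                 # platform -> config edict
create_dataset = DS_DICT[config.ds_type]    # config -> dataset factory
print(create_dataset("/path/to/data", do_train=False, cfg=config))
```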
@@ -0,0 +1,28 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+rm -rf evaluation
+mkdir evaluation
+cp ./*.py ./evaluation
+cp -r ./src ./evaluation
+cd ./evaluation || exit
+
+DATA_DIR=$1
+CKPT_DIR=$2
+
+echo "start evaluation"
+
+python eval.py --dataset_path=$DATA_DIR --checkpoint_path=$CKPT_DIR --platform='CPU' > eval.log 2>&1 &
@@ -0,0 +1,25 @@
+#!/bin/bash
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+DATA_DIR=$1
+
+rm -rf train_standalone
+mkdir ./train_standalone
+cd ./train_standalone || exit
+env > env.log
+python -u ../train.py \
+    --dataset_path=$DATA_DIR --platform=CPU> log.txt 2>&1 &
+cd ../
@@ -17,13 +17,15 @@ network config setting, will be used in main.py
 """
 from easydict import EasyDict as edict

-config = edict({
+config_ascend = edict({
     'is_save_on_master': False,

     'batch_size': 128,
     'epoch_size': 250,
     'num_classes': 1000,
     'work_nums': 8,
+    'ds_type': 'imagenet',
+    'ds_sink_mode': True,

     'loss_scale': 1024,
     'smooth_factor': 0.1,
@@ -42,3 +44,57 @@ config = edict({
     'warmup_epochs': 1,
     'start_epoch': 1,
 })
+
+config_gpu = edict({
+    'is_save_on_master': False,
+
+    'batch_size': 128,
+    'epoch_size': 250,
+    'num_classes': 1000,
+    'work_nums': 8,
+    'ds_type': 'imagenet',
+    'ds_sink_mode': True,
+
+    'loss_scale': 1024,
+    'smooth_factor': 0.1,
+    'weight_decay': 0.00004,
+    'momentum': 0.9,
+    'amp_level': 'O0',
+    'decay': 0.9,
+    'epsilon': 1.0,
+
+    'keep_checkpoint_max': 10,
+    'save_checkpoint_epochs': 10,
+
+    'lr_init': 0.00004,
+    'lr_end': 0.000004,
+    'lr_max': 0.4,
+    'warmup_epochs': 1,
+    'start_epoch': 1,
+})
+
+config_cpu = edict({
+    'batch_size': 128,
+    'epoch_size': 250,
+    'num_classes': 10,
+    'work_nums': 8,
+    'ds_type': 'cifar10',
+    'ds_sink_mode': False,
+
+    'loss_scale': 1024,
+    'smooth_factor': 0.1,
+    'weight_decay': 0.00004,
+    'momentum': 0.9,
+    'amp_level': 'O0',
+    'decay': 0.9,
+    'epsilon': 1.0,
+
+    'keep_checkpoint_max': 10,
+    'save_checkpoint_epochs': 10,
+
+    'lr_init': 0.00004,
+    'lr_end': 0.000004,
+    'lr_max': 0.4,
+    'warmup_epochs': 1,
+    'start_epoch': 1,
+})
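train.py and eval.py read these settings as attributes (config.ds_type, config.batch_size) and also attach new fields at runtime (config.rank, config.group_size), which is why the configs are EasyDict objects rather than plain dicts. A small sketch of the behavior being relied on, with values abbreviated from config_cpu above:

```python
from easydict import EasyDict as edict

cfg = edict({'ds_type': 'cifar10', 'num_classes': 10, 'ds_sink_mode': False})

# keys are readable both as attributes and as dict items
assert cfg.ds_type == cfg['ds_type'] == 'cifar10'

# new fields can be attached after construction, as train.py does with rank/group_size
cfg.rank, cfg.group_size = 0, 1
print(cfg.rank, cfg.group_size)  # 0 1
```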
@@ -12,68 +12,100 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ============================================================================
 """Create train or eval dataset."""
-import os
-import mindspore.common.dtype as mstype
-import mindspore.dataset as de
-import mindspore.dataset.vision.c_transforms as C
-import mindspore.dataset.transforms.c_transforms as C2
-from src.config import config
-
-
-device_id = int(os.getenv('DEVICE_ID', '0'))
-device_num = int(os.getenv('RANK_SIZE', '1'))
-
-
-def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, shard_id=0):
-    """
-    Create a train or eval dataset.
-    Data operations, will be used in train.py and eval.py
-    """
+import os
+
+import mindspore.common.dtype as mstype
+import mindspore.dataset as ds
+import mindspore.dataset.transforms.c_transforms as C2
+import mindspore.dataset.vision.c_transforms as C
+
+
+def create_dataset_imagenet(dataset_path, do_train, cfg, repeat_num=1):
+    """
+    create a train or eval dataset
+
+    Args:
-        dataset_path (str): The path of dataset.
-        do_train (bool): Whether dataset is used for train or eval.
-        repeat_num (int): The repeat times of dataset. Default: 1.
-        batch_size (int): The batch size of dataset. Default: 32.
+        dataset_path(string): the path of dataset.
+        do_train(bool): whether dataset is used for train or eval.
+        cfg (dict): the config for creating dataset.
+        repeat_num(int): the repeat times of dataset. Default: 1.
+
+    Returns:
-        Dataset.
+        dataset
+    """

-    do_shuffle = bool(do_train)
-
-    if device_num == 1 or not do_train:
-        ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=config.work_nums, shuffle=do_shuffle)
+    if cfg.group_size == 1:
+        data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=cfg.work_nums, shuffle=do_train)
     else:
-        ds = de.ImageFolderDataset(dataset_path, num_parallel_workers=config.work_nums,
-                                   shuffle=do_shuffle, num_shards=device_num, shard_id=shard_id)
-
-    image_length = 299
+        data_set = ds.ImageFolderDataset(dataset_path, num_parallel_workers=cfg.work_nums, shuffle=do_train,
+                                         num_shards=cfg.group_size, shard_id=cfg.rank)
+    # define map operations
+    size = 299
     if do_train:
         trans = [
-            C.RandomCropDecodeResize(image_length, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
+            C.RandomCropDecodeResize(size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
             C.RandomHorizontalFlip(prob=0.5),
             C.RandomColorAdjust(brightness=0.4, contrast=0.4, saturation=0.4)
         ]
     else:
         trans = [
             C.Decode(),
-            C.Resize(image_length),
-            C.CenterCrop(image_length)
+            C.Resize(size),
+            C.CenterCrop(size)
         ]
     trans += [
         C.Rescale(1.0 / 255.0, 0.0),
         C.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
         C.HWC2CHW()
     ]
     type_cast_op = C2.TypeCast(mstype.int32)
+    data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=cfg.work_nums)
+    data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=cfg.work_nums)
+    # apply batch operations
+    data_set = data_set.batch(cfg.batch_size, drop_remainder=True)
+    # apply dataset repeat operation
+    data_set = data_set.repeat(repeat_num)
+    return data_set
+
+
+def create_dataset_cifar10(dataset_path, do_train, cfg, repeat_num=1):
+    """
+    create a train or eval dataset
+
+    Args:
+        dataset_path(string): the path of dataset.
+        do_train(bool): whether dataset is used for train or eval.
+        cfg (dict): the config for creating dataset.
+        repeat_num(int): the repeat times of dataset. Default: 1.
+
+    Returns:
+        dataset
+    """
+    dataset_path = os.path.join(dataset_path, "cifar-10-batches-bin" if do_train else "cifar-10-verify-bin")
+    if cfg.group_size == 1:
+        data_set = ds.Cifar10Dataset(dataset_path, num_parallel_workers=cfg.work_nums, shuffle=do_train)
+    else:
+        data_set = ds.Cifar10Dataset(dataset_path, num_parallel_workers=cfg.work_nums, shuffle=do_train,
+                                     num_shards=cfg.group_size, shard_id=cfg.rank)
+
+    # define map operations
+    trans = []
+    if do_train:
+        trans.append(C.RandomCrop((32, 32), (4, 4, 4, 4)))
+        trans.append(C.RandomHorizontalFlip(prob=0.5))
+
+    trans.append(C.Resize((299, 299)))
+    trans.append(C.Rescale(1.0 / 255.0, 0.0))
+    trans.append(C.Normalize(mean=[0.4914, 0.4822, 0.4465], std=[0.2023, 0.1994, 0.2010]))
+    trans.append(C.HWC2CHW())
+
+    type_cast_op = C2.TypeCast(mstype.int32)
+
-    ds = ds.map(input_columns="label", operations=type_cast_op, num_parallel_workers=config.work_nums)
-    ds = ds.map(input_columns="image", operations=trans, num_parallel_workers=config.work_nums)
-
+    data_set = data_set.map(operations=trans, input_columns="image", num_parallel_workers=cfg.work_nums)
+    data_set = data_set.map(operations=type_cast_op, input_columns="label", num_parallel_workers=cfg.work_nums)
     # apply batch operations
-    ds = ds.batch(batch_size, drop_remainder=True)
-
+    data_set = data_set.batch(cfg.batch_size, drop_remainder=do_train)
     # apply dataset repeat operation
-    ds = ds.repeat(repeat_num)
-    return ds
+    data_set = data_set.repeat(repeat_num)
+    return data_set
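A quick way to sanity-check the new CIFAR-10 pipeline is to call it with a minimal config. The snippet below is a hypothetical local check; it assumes the CIFAR-10 binary folders live under ./cifar10 and that it is run from the model directory so src.dataset is importable:

```python
from easydict import EasyDict as edict
from src.dataset import create_dataset_cifar10

cfg = edict({'work_nums': 8, 'batch_size': 128, 'rank': 0, 'group_size': 1})
data = create_dataset_cifar10("./cifar10", do_train=False, cfg=cfg)
print(data.get_dataset_size())          # number of batches per epoch
for images, labels in data.create_tuple_iterator():
    print(images.shape, labels.shape)   # expected: (128, 3, 299, 299) and (128,)
    break
```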
@@ -13,32 +13,63 @@
 # limitations under the License.
 # ============================================================================
 """train imagenet"""
-import os
 import argparse
 import math
+import os

 import numpy as np

-from mindspore.communication import init, get_rank
-from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor
-from mindspore.train.model import ParallelMode
-from mindspore.train.loss_scale_manager import FixedLossScaleManager
 from mindspore import Model
-from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
-from mindspore.nn import RMSProp
 from mindspore import Tensor
 from mindspore import context
 from mindspore.common import set_seed
 from mindspore.common.initializer import XavierUniform, initializer
+from mindspore.communication import init, get_rank, get_group_size
+from mindspore.nn import RMSProp
+from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
+from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, TimeMonitor, LossMonitor
+from mindspore.train.loss_scale_manager import FixedLossScaleManager
+from mindspore.train.model import ParallelMode
 from mindspore.train.serialization import load_checkpoint, load_param_into_net

+from src.config import config_ascend, config_gpu, config_cpu
+from src.dataset import create_dataset_imagenet, create_dataset_cifar10
 from src.inceptionv4 import Inceptionv4
-from src.dataset import create_dataset, device_num
-
-from src.config import config

 os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'
 set_seed(1)

+CFG_DICT = {
+    "Ascend": config_ascend,
+    "GPU": config_gpu,
+    "CPU": config_cpu,
+}
+
+DS_DICT = {
+    "imagenet": create_dataset_imagenet,
+    "cifar10": create_dataset_cifar10,
+}
+
+device_num = int(os.getenv('RANK_SIZE', '1'))
+
+
+def parse_args():
+    '''parse_args'''
+    arg_parser = argparse.ArgumentParser(description='InceptionV4 image classification training')
+    arg_parser.add_argument('--dataset_path', type=str, default='', help='Dataset path')
+    arg_parser.add_argument('--device_id', type=int, default=0, help='device id')
+    arg_parser.add_argument('--platform', type=str, default='Ascend', choices=("Ascend", "GPU", "CPU"),
+                            help='Platform, support Ascend, GPU, CPU.')
+    arg_parser.add_argument('--resume', type=str, default='', help='resume training with existed checkpoint')
+    args_opt = arg_parser.parse_args()
+    return args_opt
+
+
+args = parse_args()
+
+config = CFG_DICT[args.platform]
+create_dataset = DS_DICT[config.ds_type]
+
+
 def generate_cosine_lr(steps_per_epoch, total_epochs,
                        lr_init=config.lr_init,
                        lr_end=config.lr_end,
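The learning-rate schedule that consumes lr_init/lr_end/lr_max/warmup_epochs is only partially visible in this hunk. Below is a generic warmup-plus-cosine-decay sketch of the kind of schedule these fields describe; it is an illustration, not the repository's exact generate_cosine_lr:

```python
import math
import numpy as np

def cosine_lr_sketch(steps_per_epoch, total_epochs, lr_init=0.00004, lr_end=0.000004,
                     lr_max=0.4, warmup_epochs=1):
    """Linear warmup from lr_init to lr_max, then cosine decay towards lr_end."""
    total_steps = steps_per_epoch * total_epochs
    warmup_steps = steps_per_epoch * warmup_epochs
    lr = []
    for step in range(total_steps):
        if step < warmup_steps:
            lr.append(lr_init + (lr_max - lr_init) * step / warmup_steps)
        else:
            progress = (step - warmup_steps) / max(total_steps - warmup_steps, 1)
            lr.append(lr_end + (lr_max - lr_end) * 0.5 * (1 + math.cos(math.pi * progress)))
    return np.array(lr, dtype=np.float32)

schedule = cosine_lr_sketch(steps_per_epoch=390, total_epochs=250)
print(schedule[0], schedule.max(), schedule[-1])  # ~lr_init, ~lr_max, ~lr_end
```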
@@ -87,7 +118,6 @@ def inception_v4_train():
     context.set_context(device_id=args.device_id)
     context.set_context(enable_graph_kernel=False)

-    rank = 0
     if device_num > 1:
         if args.platform == "Ascend":
             init(backend_name='hccl')
@@ -96,15 +126,18 @@
         else:
             raise ValueError("Unsupported device target.")

-        rank = get_rank()
+        config.rank = get_rank()
+        config.group_size = get_group_size()
         context.set_auto_parallel_context(device_num=device_num,
                                           parallel_mode=ParallelMode.DATA_PARALLEL,
                                           gradients_mean=True,
                                           all_reduce_fusion_config=[200, 400])
+    else:
+        config.rank = 0
+        config.group_size = 1

     # create dataset
-    train_dataset = create_dataset(dataset_path=args.dataset_path, do_train=True,
-                                   repeat_num=1, batch_size=config.batch_size, shard_id=rank)
+    train_dataset = create_dataset(dataset_path=args.dataset_path, do_train=True, cfg=config)
     train_step_size = train_dataset.get_dataset_size()

     # create model
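config.rank and config.group_size feed straight into the num_shards/shard_id arguments of the dataset constructors in src/dataset.py, so each process reads a disjoint slice of the data. Conceptually the split behaves like the round-robin sketch below, which is only an illustration of the idea, not MindSpore's actual sharding implementation:

```python
def shard_indices(num_samples, num_shards, shard_id):
    """Round-robin assignment of sample indices to one worker (illustrative only)."""
    return list(range(shard_id, num_samples, num_shards))

# 10 samples split across group_size=4 workers: every index lands in exactly one shard
shards = [shard_indices(10, 4, rank) for rank in range(4)]
print(shards)  # [[0, 4, 8], [1, 5, 9], [2, 6], [3, 7]]
```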
@@ -140,23 +173,16 @@ def inception_v4_train():

     loss_scale_manager = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)

-
-    if args.platform == "Ascend":
     model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc', 'top_1_accuracy', 'top_5_accuracy'},
                   loss_scale_manager=loss_scale_manager, amp_level=config.amp_level)
-    elif args.platform == "GPU":
-        model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc', 'top_1_accuracy', 'top_5_accuracy'},
-                      loss_scale_manager=loss_scale_manager, amp_level='O0')
-    else:
-        raise ValueError("Unsupported device target.")

     # define callbacks
     performance_cb = TimeMonitor(data_size=train_step_size)
     loss_cb = LossMonitor(per_print_times=train_step_size)
     ckp_save_step = config.save_checkpoint_epochs * train_step_size
     config_ck = CheckpointConfig(save_checkpoint_steps=ckp_save_step, keep_checkpoint_max=config.keep_checkpoint_max)
-    ckpoint_cb = ModelCheckpoint(prefix=f"inceptionV4-train-rank{rank}",
-                                 directory='ckpts_rank_' + str(rank), config=config_ck)
+    ckpoint_cb = ModelCheckpoint(prefix=f"inceptionV4-train-rank{config.rank}",
+                                 directory='ckpts_rank_' + str(config.rank), config=config_ck)
     callbacks = [performance_cb, loss_cb]
     if device_num > 1 and config.is_save_on_master:
         if args.device_id == 0:
@@ -165,21 +191,9 @@ def inception_v4_train():
             callbacks.append(ckpoint_cb)

     # train model
-    model.train(config.epoch_size, train_dataset, callbacks=callbacks, dataset_sink_mode=True)
-
-
-def parse_args():
-    '''parse_args'''
-    arg_parser = argparse.ArgumentParser(description='InceptionV4 image classification training')
-    arg_parser.add_argument('--dataset_path', type=str, default='', help='Dataset path')
-    arg_parser.add_argument('--device_id', type=int, default=0, help='device id')
-    arg_parser.add_argument('--platform', type=str, default='Ascend', choices=("Ascend", "GPU"),
-                            help='Platform, support Ascend, GPU.')
-    arg_parser.add_argument('--resume', type=str, default='', help='resume training with existed checkpoint')
-    args_opt = arg_parser.parse_args()
-    return args_opt
+    model.train(config.epoch_size, train_dataset, callbacks=callbacks, dataset_sink_mode=config.ds_sink_mode)


 if __name__ == '__main__':
     args = parse_args()
     inception_v4_train()
     print('Inceptionv4 training success!')