forked from OSSInnovation/mindspore
googlenet support imagenet dataset on Ascend
This commit is contained in:
parent
fc8bd0dd03
commit
145545a9cb
|
@ -25,35 +25,57 @@ from mindspore.train.model import Model
|
|||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.common import set_seed
|
||||
|
||||
from src.config import cifar_cfg as cfg
|
||||
from src.dataset import create_dataset
|
||||
from src.config import cifar_cfg, imagenet_cfg
|
||||
from src.dataset import create_dataset_cifar10, create_dataset_imagenet
|
||||
|
||||
from src.googlenet import GoogleNet
|
||||
|
||||
set_seed(1)
|
||||
|
||||
parser = argparse.ArgumentParser(description='googlenet')
|
||||
parser.add_argument('--dataset_name', type=str, default='cifar10', choices=['imagenet', 'cifar10'],
|
||||
help='dataset name.')
|
||||
parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if args_opt.dataset_name == 'cifar10':
|
||||
cfg = cifar_cfg
|
||||
dataset = create_dataset_cifar10(cfg.data_path, 1, False)
|
||||
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)
|
||||
net = GoogleNet(num_classes=cfg.num_classes)
|
||||
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, cfg.momentum,
|
||||
weight_decay=cfg.weight_decay)
|
||||
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
|
||||
|
||||
elif args_opt.dataset_name == "imagenet":
|
||||
cfg = imagenet_cfg
|
||||
dataset = create_dataset_imagenet(cfg.val_data_path, 1, False)
|
||||
if not cfg.use_label_smooth:
|
||||
cfg.label_smooth_factor = 0.0
|
||||
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean",
|
||||
smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes)
|
||||
net = GoogleNet(num_classes=cfg.num_classes)
|
||||
model = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})
|
||||
|
||||
else:
|
||||
raise ValueError("dataset is not support.")
|
||||
|
||||
device_target = cfg.device_target
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
|
||||
if device_target == "Ascend":
|
||||
context.set_context(device_id=cfg.device_id)
|
||||
|
||||
net = GoogleNet(num_classes=cfg.num_classes)
|
||||
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, cfg.momentum,
|
||||
weight_decay=cfg.weight_decay)
|
||||
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
|
||||
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
|
||||
|
||||
if device_target == "Ascend":
|
||||
param_dict = load_checkpoint(cfg.checkpoint_path)
|
||||
else: # GPU
|
||||
if args_opt.checkpoint_path is not None:
|
||||
param_dict = load_checkpoint(args_opt.checkpoint_path)
|
||||
print("load checkpoint from [{}].".format(args_opt.checkpoint_path))
|
||||
else:
|
||||
param_dict = load_checkpoint(cfg.checkpoint_path)
|
||||
print("load checkpoint from [{}].".format(cfg.checkpoint_path))
|
||||
|
||||
load_param_into_net(net, param_dict)
|
||||
net.set_train(False)
|
||||
dataset = create_dataset(cfg.data_path, 1, False)
|
||||
|
||||
acc = model.eval(dataset)
|
||||
print("accuracy: ", acc)
|
||||
|
|
|
@ -16,18 +16,32 @@
|
|||
##############export checkpoint file into air and onnx models#################
|
||||
python export.py
|
||||
"""
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
import mindspore as ms
|
||||
from mindspore import Tensor
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
|
||||
|
||||
from src.config import cifar_cfg as cfg
|
||||
from src.config import cifar_cfg, imagenet_cfg
|
||||
from src.googlenet import GoogleNet
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='Classification')
|
||||
parser.add_argument('--dataset_name', type=str, default='cifar10', choices=['imagenet', 'cifar10'],
|
||||
help='dataset name.')
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
if args_opt.dataset_name == 'cifar10':
|
||||
cfg = cifar_cfg
|
||||
elif args_opt.dataset_name == 'imagenet':
|
||||
cfg = imagenet_cfg
|
||||
else:
|
||||
raise ValueError("dataset is not support.")
|
||||
|
||||
net = GoogleNet(num_classes=cfg.num_classes)
|
||||
|
||||
assert cfg.checkpoint_path is not None, "cfg.checkpoint_path is None."
|
||||
param_dict = load_checkpoint(cfg.checkpoint_path)
|
||||
load_param_into_net(net, param_dict)
|
||||
|
||||
|
|
|
@ -14,9 +14,9 @@
|
|||
# limitations under the License.
|
||||
# ============================================================================
|
||||
|
||||
if [ $# != 1 ]
|
||||
if [ $# != 1 ] && [ $# != 2 ]
|
||||
then
|
||||
echo "Usage: sh run_train.sh [RANK_TABLE_FILE]"
|
||||
echo "Usage: sh run_train.sh [RANK_TABLE_FILE] [cifar10|imagenet]"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
|
@ -26,6 +26,19 @@ then
|
|||
exit 1
|
||||
fi
|
||||
|
||||
|
||||
dataset_type='cifar10'
|
||||
if [ $# == 2 ]
|
||||
then
|
||||
if [ $2 != "cifar10" ] && [ $2 != "imagenet" ]
|
||||
then
|
||||
echo "error: the selected dataset is neither cifar10 nor imagenet"
|
||||
exit 1
|
||||
fi
|
||||
dataset_type=$2
|
||||
fi
|
||||
|
||||
|
||||
ulimit -u unlimited
|
||||
export DEVICE_NUM=8
|
||||
export RANK_SIZE=8
|
||||
|
@ -43,9 +56,9 @@ do
|
|||
mkdir ./train_parallel$i
|
||||
cp -r ./src ./train_parallel$i
|
||||
cp ./train.py ./train_parallel$i
|
||||
echo "start training for rank $RANK_ID, device $DEVICE_ID"
|
||||
echo "start training for rank $RANK_ID, device $DEVICE_ID, $dataset_type"
|
||||
cd ./train_parallel$i ||exit
|
||||
env > env.log
|
||||
python train.py --device_id=$i > log 2>&1 &
|
||||
python train.py --device_id=$i --dataset_name=$dataset_type> log 2>&1 &
|
||||
cd ..
|
||||
done
|
||||
|
|
|
@ -18,6 +18,7 @@ network config setting, will be used in main.py
|
|||
from easydict import EasyDict as edict
|
||||
|
||||
cifar_cfg = edict({
|
||||
'name': 'cifar10',
|
||||
'pre_trained': False,
|
||||
'num_classes': 10,
|
||||
'lr_init': 0.1,
|
||||
|
@ -30,9 +31,45 @@ cifar_cfg = edict({
|
|||
'image_width': 224,
|
||||
'data_path': './cifar10',
|
||||
'device_target': 'Ascend',
|
||||
'device_id': 4,
|
||||
'device_id': 0,
|
||||
'keep_checkpoint_max': 10,
|
||||
'checkpoint_path': './train_googlenet_cifar10-125_390.ckpt',
|
||||
'onnx_filename': 'googlenet.onnx',
|
||||
'air_filename': 'googlenet.air'
|
||||
})
|
||||
|
||||
imagenet_cfg = edict({
|
||||
'name': 'imagenet',
|
||||
'pre_trained': False,
|
||||
'num_classes': 1000,
|
||||
'lr_init': 0.1,
|
||||
'batch_size': 256,
|
||||
'epoch_size': 300,
|
||||
'momentum': 0.9,
|
||||
'weight_decay': 1e-4,
|
||||
'buffer_size': None, # invalid parameter
|
||||
'image_height': 224,
|
||||
'image_width': 224,
|
||||
'data_path': './ImageNet_Original/train/',
|
||||
'val_data_path': './ImageNet_Original/val/',
|
||||
'device_target': 'Ascend',
|
||||
'device_id': 0,
|
||||
'keep_checkpoint_max': 10,
|
||||
'checkpoint_path': None,
|
||||
'onnx_filename': 'googlenet.onnx',
|
||||
'air_filename': 'googlenet.air',
|
||||
|
||||
# optimizer and lr related
|
||||
'lr_scheduler': 'exponential',
|
||||
'lr_epochs': [70, 140, 210, 280],
|
||||
'lr_gamma': 0.3,
|
||||
'eta_min': 0.0,
|
||||
'T_max': 150,
|
||||
'warmup_epochs': 0,
|
||||
|
||||
# loss related
|
||||
'is_dynamic_loss_scale': 0,
|
||||
'loss_scale': 1024,
|
||||
'label_smooth_factor': 0.1,
|
||||
'use_label_smooth': True,
|
||||
})
|
||||
|
|
|
@ -21,10 +21,10 @@ import mindspore.common.dtype as mstype
|
|||
import mindspore.dataset as ds
|
||||
import mindspore.dataset.transforms.c_transforms as C
|
||||
import mindspore.dataset.vision.c_transforms as vision
|
||||
from src.config import cifar_cfg as cfg
|
||||
from src.config import cifar_cfg, imagenet_cfg
|
||||
|
||||
|
||||
def create_dataset(data_home, repeat_num=1, training=True):
|
||||
def create_dataset_cifar10(data_home, repeat_num=1, training=True):
|
||||
"""Data operations."""
|
||||
ds.config.set_seed(1)
|
||||
data_dir = os.path.join(data_home, "cifar-10-batches-bin")
|
||||
|
@ -37,14 +37,14 @@ def create_dataset(data_home, repeat_num=1, training=True):
|
|||
else:
|
||||
data_set = ds.Cifar10Dataset(data_dir, num_shards=rank_size, shard_id=rank_id, shuffle=False)
|
||||
|
||||
resize_height = cfg.image_height
|
||||
resize_width = cfg.image_width
|
||||
resize_height = cifar_cfg.image_height
|
||||
resize_width = cifar_cfg.image_width
|
||||
|
||||
# define map operations
|
||||
random_crop_op = vision.RandomCrop((32, 32), (4, 4, 4, 4)) # padding_mode default CONSTANT
|
||||
random_horizontal_op = vision.RandomHorizontalFlip()
|
||||
resize_op = vision.Resize((resize_height, resize_width)) # interpolation default BILINEAR
|
||||
rescale_op = vision.Rescale(1.0/255.0, 0.0)
|
||||
rescale_op = vision.Rescale(1.0 / 255.0, 0.0)
|
||||
normalize_op = vision.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
|
||||
changeswap_op = vision.HWC2CHW()
|
||||
type_cast_op = C.TypeCast(mstype.int32)
|
||||
|
@ -59,7 +59,7 @@ def create_dataset(data_home, repeat_num=1, training=True):
|
|||
data_set = data_set.map(input_columns="image", operations=c_trans)
|
||||
|
||||
# apply batch operations
|
||||
data_set = data_set.batch(batch_size=cfg.batch_size, drop_remainder=True)
|
||||
data_set = data_set.batch(batch_size=cifar_cfg.batch_size, drop_remainder=True)
|
||||
|
||||
# apply repeat operations
|
||||
data_set = data_set.repeat(repeat_num)
|
||||
|
@ -67,6 +67,67 @@ def create_dataset(data_home, repeat_num=1, training=True):
|
|||
return data_set
|
||||
|
||||
|
||||
def create_dataset_imagenet(dataset_path, repeat_num=1, training=True,
|
||||
num_parallel_workers=None, shuffle=None):
|
||||
"""
|
||||
create a train or eval imagenet2012 dataset for resnet50
|
||||
|
||||
Args:
|
||||
dataset_path(string): the path of dataset.
|
||||
do_train(bool): whether dataset is used for train or eval.
|
||||
repeat_num(int): the repeat times of dataset. Default: 1
|
||||
batch_size(int): the batch size of dataset. Default: 32
|
||||
target(str): the device target. Default: Ascend
|
||||
|
||||
Returns:
|
||||
dataset
|
||||
"""
|
||||
|
||||
device_num, rank_id = _get_rank_info()
|
||||
|
||||
if device_num == 1:
|
||||
data_set = ds.ImageFolderDatasetV2(dataset_path, num_parallel_workers=num_parallel_workers, shuffle=shuffle)
|
||||
else:
|
||||
data_set = ds.ImageFolderDatasetV2(dataset_path, num_parallel_workers=num_parallel_workers, shuffle=shuffle,
|
||||
num_shards=device_num, shard_id=rank_id)
|
||||
|
||||
assert imagenet_cfg.image_height == imagenet_cfg.image_width, "image_height not equal image_width"
|
||||
image_size = imagenet_cfg.image_height
|
||||
mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
|
||||
std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
|
||||
|
||||
# define map operations
|
||||
if training:
|
||||
transform_img = [
|
||||
vision.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
|
||||
vision.RandomHorizontalFlip(prob=0.5),
|
||||
vision.RandomColorAdjust(0.4, 0.4, 0.4, 0.1),
|
||||
vision.Normalize(mean=mean, std=std),
|
||||
vision.HWC2CHW()
|
||||
]
|
||||
else:
|
||||
transform_img = [
|
||||
vision.Decode(),
|
||||
vision.Resize(256),
|
||||
vision.CenterCrop(image_size),
|
||||
vision.Normalize(mean=mean, std=std),
|
||||
vision.HWC2CHW()
|
||||
]
|
||||
|
||||
transform_label = [C.TypeCast(mstype.int32)]
|
||||
|
||||
data_set = data_set.map(input_columns="image", num_parallel_workers=8, operations=transform_img)
|
||||
data_set = data_set.map(input_columns="label", num_parallel_workers=8, operations=transform_label)
|
||||
|
||||
# apply batch operations
|
||||
data_set = data_set.batch(imagenet_cfg.batch_size, drop_remainder=True)
|
||||
|
||||
# apply dataset repeat operation
|
||||
data_set = data_set.repeat(repeat_num)
|
||||
|
||||
return data_set
|
||||
|
||||
|
||||
def _get_rank_info():
|
||||
"""
|
||||
get rank size and rank id
|
||||
|
|
|
@ -112,6 +112,7 @@ class GoogleNet(nn.Cell):
|
|||
|
||||
|
||||
def construct(self, x):
|
||||
"""construct"""
|
||||
x = self.conv1(x)
|
||||
x = self.maxpool1(x)
|
||||
|
||||
|
|
|
@ -0,0 +1,20 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""lr"""
|
||||
|
||||
def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
|
||||
lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
|
||||
lr = float(init_lr) + lr_inc * current_step
|
||||
return lr
|
|
@ -0,0 +1,39 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""lr"""
|
||||
|
||||
import math
|
||||
import numpy as np
|
||||
|
||||
from .linear_warmup import linear_warmup_lr
|
||||
|
||||
|
||||
def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0):
|
||||
""" warmup cosine annealing lr"""
|
||||
base_lr = lr
|
||||
warmup_init_lr = 0
|
||||
total_steps = int(max_epoch * steps_per_epoch)
|
||||
warmup_steps = int(warmup_epochs * steps_per_epoch)
|
||||
|
||||
lr_each_step = []
|
||||
for i in range(total_steps):
|
||||
last_epoch = i // steps_per_epoch
|
||||
if i < warmup_steps:
|
||||
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
|
||||
else:
|
||||
lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi * last_epoch / T_max)) / 2
|
||||
lr_each_step.append(lr)
|
||||
|
||||
return np.array(lr_each_step).astype(np.float32)
|
|
@ -0,0 +1,59 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""lr"""
|
||||
|
||||
from collections import Counter
|
||||
import numpy as np
|
||||
|
||||
from .linear_warmup import linear_warmup_lr
|
||||
|
||||
|
||||
def warmup_step_lr(lr, lr_epochs, steps_per_epoch, warmup_epochs, max_epoch, gamma=0.1):
|
||||
"""warmup step lr"""
|
||||
base_lr = lr
|
||||
warmup_init_lr = 0
|
||||
total_steps = int(max_epoch * steps_per_epoch)
|
||||
warmup_steps = int(warmup_epochs * steps_per_epoch)
|
||||
milestones = lr_epochs
|
||||
milestones_steps = []
|
||||
for milestone in milestones:
|
||||
milestones_step = milestone * steps_per_epoch
|
||||
milestones_steps.append(milestones_step)
|
||||
|
||||
lr_each_step = []
|
||||
lr = base_lr
|
||||
milestones_steps_counter = Counter(milestones_steps)
|
||||
for i in range(total_steps):
|
||||
if i < warmup_steps:
|
||||
lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
|
||||
else:
|
||||
lr = lr * gamma ** milestones_steps_counter[i]
|
||||
lr_each_step.append(lr)
|
||||
|
||||
return np.array(lr_each_step).astype(np.float32)
|
||||
|
||||
|
||||
def multi_step_lr(lr, milestones, steps_per_epoch, max_epoch, gamma=0.1):
|
||||
"""lr"""
|
||||
return warmup_step_lr(lr, milestones, steps_per_epoch, 0, max_epoch, gamma=gamma)
|
||||
|
||||
|
||||
def step_lr(lr, epoch_size, steps_per_epoch, max_epoch, gamma=0.1):
|
||||
"""lr"""
|
||||
lr_epochs = []
|
||||
for i in range(1, max_epoch):
|
||||
if i % epoch_size == 0:
|
||||
lr_epochs.append(i)
|
||||
return multi_step_lr(lr, lr_epochs, steps_per_epoch, max_epoch, gamma=gamma)
|
|
@ -27,18 +27,19 @@ from mindspore import context
|
|||
from mindspore.communication.management import init, get_rank
|
||||
from mindspore.nn.optim.momentum import Momentum
|
||||
from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
|
||||
from mindspore.train.loss_scale_manager import DynamicLossScaleManager, FixedLossScaleManager
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.context import ParallelMode
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from mindspore.common import set_seed
|
||||
|
||||
from src.config import cifar_cfg as cfg
|
||||
from src.dataset import create_dataset
|
||||
from src.config import cifar_cfg, imagenet_cfg
|
||||
from src.dataset import create_dataset_cifar10, create_dataset_imagenet
|
||||
from src.googlenet import GoogleNet
|
||||
|
||||
set_seed(1)
|
||||
|
||||
def lr_steps(global_step, lr_max=None, total_epochs=None, steps_per_epoch=None):
|
||||
def lr_steps_cifar10(global_step, lr_max=None, total_epochs=None, steps_per_epoch=None):
|
||||
"""Set learning rate."""
|
||||
lr_each_step = []
|
||||
total_steps = steps_per_epoch * total_epochs
|
||||
|
@ -59,11 +60,46 @@ def lr_steps(global_step, lr_max=None, total_epochs=None, steps_per_epoch=None):
|
|||
return learning_rate
|
||||
|
||||
|
||||
def lr_steps_imagenet(_cfg, steps_per_epoch):
|
||||
"""lr step for imagenet"""
|
||||
from src.lr_scheduler.warmup_step_lr import warmup_step_lr
|
||||
from src.lr_scheduler.warmup_cosine_annealing_lr import warmup_cosine_annealing_lr
|
||||
if _cfg.lr_scheduler == 'exponential':
|
||||
_lr = warmup_step_lr(_cfg.lr_init,
|
||||
_cfg.lr_epochs,
|
||||
steps_per_epoch,
|
||||
_cfg.warmup_epochs,
|
||||
_cfg.epoch_size,
|
||||
gamma=_cfg.lr_gamma,
|
||||
)
|
||||
elif _cfg.lr_scheduler == 'cosine_annealing':
|
||||
_lr = warmup_cosine_annealing_lr(_cfg.lr_init,
|
||||
steps_per_epoch,
|
||||
_cfg.warmup_epochs,
|
||||
_cfg.epoch_size,
|
||||
_cfg.T_max,
|
||||
_cfg.eta_min)
|
||||
else:
|
||||
raise NotImplementedError(_cfg.lr_scheduler)
|
||||
|
||||
return _lr
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser(description='Cifar10 classification')
|
||||
parser = argparse.ArgumentParser(description='Classification')
|
||||
parser.add_argument('--dataset_name', type=str, default='cifar10', choices=['imagenet', 'cifar10'],
|
||||
help='dataset name.')
|
||||
parser.add_argument('--device_id', type=int, default=None, help='device id of GPU or Ascend. (Default: None)')
|
||||
args_opt = parser.parse_args()
|
||||
|
||||
if args_opt.dataset_name == "cifar10":
|
||||
cfg = cifar_cfg
|
||||
elif args_opt.dataset_name == "imagenet":
|
||||
cfg = imagenet_cfg
|
||||
else:
|
||||
raise ValueError("Unsupport dataset.")
|
||||
|
||||
# set context
|
||||
device_target = cfg.device_target
|
||||
|
||||
context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
|
||||
|
@ -90,7 +126,13 @@ if __name__ == '__main__':
|
|||
else:
|
||||
raise ValueError("Unsupported platform.")
|
||||
|
||||
dataset = create_dataset(cfg.data_path, 1)
|
||||
if args_opt.dataset_name == "cifar10":
|
||||
dataset = create_dataset_cifar10(cfg.data_path, 1)
|
||||
elif args_opt.dataset_name == "imagenet":
|
||||
dataset = create_dataset_imagenet(cfg.data_path, 1)
|
||||
else:
|
||||
raise ValueError("Unsupport dataset.")
|
||||
|
||||
batch_num = dataset.get_dataset_size()
|
||||
|
||||
net = GoogleNet(num_classes=cfg.num_classes)
|
||||
|
@ -98,23 +140,75 @@ if __name__ == '__main__':
|
|||
if cfg.pre_trained:
|
||||
param_dict = load_checkpoint(cfg.checkpoint_path)
|
||||
load_param_into_net(net, param_dict)
|
||||
lr = lr_steps(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size, steps_per_epoch=batch_num)
|
||||
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), Tensor(lr), cfg.momentum,
|
||||
|
||||
loss_scale_manager = None
|
||||
if args_opt.dataset_name == 'cifar10':
|
||||
lr = lr_steps_cifar10(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size, steps_per_epoch=batch_num)
|
||||
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()),
|
||||
learning_rate=Tensor(lr),
|
||||
momentum=cfg.momentum,
|
||||
weight_decay=cfg.weight_decay)
|
||||
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
|
||||
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)
|
||||
|
||||
elif args_opt.dataset_name == 'imagenet':
|
||||
lr = lr_steps_imagenet(cfg, batch_num)
|
||||
|
||||
|
||||
def get_param_groups(network):
|
||||
""" get param groups """
|
||||
decay_params = []
|
||||
no_decay_params = []
|
||||
for x in network.trainable_params():
|
||||
parameter_name = x.name
|
||||
if parameter_name.endswith('.bias'):
|
||||
# all bias not using weight decay
|
||||
# print('no decay:{}'.format(parameter_name))
|
||||
no_decay_params.append(x)
|
||||
elif parameter_name.endswith('.gamma'):
|
||||
# bn weight bias not using weight decay, be carefully for now x not include BN
|
||||
# print('no decay:{}'.format(parameter_name))
|
||||
no_decay_params.append(x)
|
||||
elif parameter_name.endswith('.beta'):
|
||||
# bn weight bias not using weight decay, be carefully for now x not include BN
|
||||
# print('no decay:{}'.format(parameter_name))
|
||||
no_decay_params.append(x)
|
||||
else:
|
||||
decay_params.append(x)
|
||||
|
||||
return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}]
|
||||
|
||||
|
||||
if cfg.is_dynamic_loss_scale:
|
||||
cfg.loss_scale = 1
|
||||
|
||||
opt = Momentum(params=get_param_groups(net),
|
||||
learning_rate=Tensor(lr),
|
||||
momentum=cfg.momentum,
|
||||
weight_decay=cfg.weight_decay,
|
||||
loss_scale=cfg.loss_scale)
|
||||
if not cfg.use_label_smooth:
|
||||
cfg.label_smooth_factor = 0.0
|
||||
loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean",
|
||||
smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes)
|
||||
|
||||
if cfg.is_dynamic_loss_scale == 1:
|
||||
loss_scale_manager = DynamicLossScaleManager(init_loss_scale=65536, scale_factor=2, scale_window=2000)
|
||||
else:
|
||||
loss_scale_manager = FixedLossScaleManager(cfg.loss_scale, drop_overflow_update=False)
|
||||
|
||||
if device_target == "Ascend":
|
||||
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},
|
||||
amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None)
|
||||
amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=loss_scale_manager)
|
||||
ckpt_save_dir = "./"
|
||||
else: # GPU
|
||||
model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},
|
||||
amp_level="O2", keep_batchnorm_fp32=True, loss_scale_manager=None)
|
||||
amp_level="O2", keep_batchnorm_fp32=True, loss_scale_manager=loss_scale_manager)
|
||||
ckpt_save_dir = "./ckpt_" + str(get_rank()) + "/"
|
||||
|
||||
config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5, keep_checkpoint_max=cfg.keep_checkpoint_max)
|
||||
time_cb = TimeMonitor(data_size=batch_num)
|
||||
ckpoint_cb = ModelCheckpoint(prefix="train_googlenet_cifar10", directory=ckpt_save_dir, config=config_ck)
|
||||
ckpoint_cb = ModelCheckpoint(prefix="train_googlenet_" + args_opt.dataset_name, directory=ckpt_save_dir,
|
||||
config=config_ck)
|
||||
loss_cb = LossMonitor()
|
||||
model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb])
|
||||
print("train success")
|
||||
|
|
Loading…
Reference in New Issue