diff --git a/model_zoo/official/cv/googlenet/eval.py b/model_zoo/official/cv/googlenet/eval.py
index 21c902c466..79b8debcbe 100644
--- a/model_zoo/official/cv/googlenet/eval.py
+++ b/model_zoo/official/cv/googlenet/eval.py
@@ -25,35 +25,57 @@ from mindspore.train.model import Model
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from mindspore.common import set_seed
 
-from src.config import cifar_cfg as cfg
-from src.dataset import create_dataset
+from src.config import cifar_cfg, imagenet_cfg
+from src.dataset import create_dataset_cifar10, create_dataset_imagenet
+
 from src.googlenet import GoogleNet
 
 set_seed(1)
 
 parser = argparse.ArgumentParser(description='googlenet')
+parser.add_argument('--dataset_name', type=str, default='cifar10', choices=['imagenet', 'cifar10'],
+                    help='dataset name.')
 parser.add_argument('--checkpoint_path', type=str, default=None, help='Checkpoint file path')
 args_opt = parser.parse_args()
 
 if __name__ == '__main__':
+
+    if args_opt.dataset_name == 'cifar10':
+        cfg = cifar_cfg
+        dataset = create_dataset_cifar10(cfg.data_path, 1, False)
+        loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)
+        net = GoogleNet(num_classes=cfg.num_classes)
+        opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, cfg.momentum,
+                       weight_decay=cfg.weight_decay)
+        model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
+
+    elif args_opt.dataset_name == "imagenet":
+        cfg = imagenet_cfg
+        dataset = create_dataset_imagenet(cfg.val_data_path, 1, False)
+        if not cfg.use_label_smooth:
+            cfg.label_smooth_factor = 0.0
+        loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean",
+                                                smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes)
+        net = GoogleNet(num_classes=cfg.num_classes)
+        model = Model(net, loss_fn=loss, metrics={'top_1_accuracy', 'top_5_accuracy'})
+
+    else:
+        raise ValueError("dataset is not supported.")
+
     device_target = cfg.device_target
 
     context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
 
     if device_target == "Ascend":
         context.set_context(device_id=cfg.device_id)
 
-    net = GoogleNet(num_classes=cfg.num_classes)
-    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), 0.01, cfg.momentum,
-                   weight_decay=cfg.weight_decay)
-    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
-    model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'})
-
-    if device_target == "Ascend":
-        param_dict = load_checkpoint(cfg.checkpoint_path)
-    else: # GPU
+    if args_opt.checkpoint_path is not None:
         param_dict = load_checkpoint(args_opt.checkpoint_path)
+        print("load checkpoint from [{}].".format(args_opt.checkpoint_path))
+    else:
+        param_dict = load_checkpoint(cfg.checkpoint_path)
+        print("load checkpoint from [{}].".format(cfg.checkpoint_path))
     load_param_into_net(net, param_dict)
     net.set_train(False)
 
-    dataset = create_dataset(cfg.data_path, 1, False)
+
     acc = model.eval(dataset)
     print("accuracy: ", acc)
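Note on the ImageNet branch above: evaluation reuses the label-smoothed loss from training. The usual formulation, which the smooth_factor/num_classes arguments configure, puts 1 - smooth_factor on the true class and spreads the remainder over the other classes. A minimal NumPy sketch for intuition only; this is not MindSpore's internal code, and the exact off-value convention inside SoftmaxCrossEntropyWithLogits may differ:

    import numpy as np

    def smoothed_softmax_ce(logits, label, smooth_factor=0.1, num_classes=1000):
        """Cross entropy against a smoothed one-hot target (illustrative sketch)."""
        # smoothed target: 1 - smooth_factor on the true class,
        # smooth_factor / (num_classes - 1) spread over the rest
        target = np.full(num_classes, smooth_factor / (num_classes - 1))
        target[label] = 1.0 - smooth_factor
        m = logits.max()
        log_probs = logits - m - np.log(np.exp(logits - m).sum())  # stable log-softmax
        return -(target * log_probs).sum()

With smooth_factor = 0.0 this degenerates to plain cross entropy, which is why both eval.py and train.py zero the factor when use_label_smooth is off.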
diff --git a/model_zoo/official/cv/googlenet/export.py b/model_zoo/official/cv/googlenet/export.py
index 49a07872ee..2b77fb9b92 100644
--- a/model_zoo/official/cv/googlenet/export.py
+++ b/model_zoo/official/cv/googlenet/export.py
@@ -16,18 +16,32 @@
 ##############export checkpoint file into air and onnx models#################
 python export.py
 """
+import argparse
 import numpy as np
 
 import mindspore as ms
 from mindspore import Tensor
 from mindspore.train.serialization import load_checkpoint, load_param_into_net, export
-from src.config import cifar_cfg as cfg
+from src.config import cifar_cfg, imagenet_cfg
 from src.googlenet import GoogleNet
 
-
 if __name__ == '__main__':
+    parser = argparse.ArgumentParser(description='Classification')
+    parser.add_argument('--dataset_name', type=str, default='cifar10', choices=['imagenet', 'cifar10'],
+                        help='dataset name.')
+    args_opt = parser.parse_args()
+
+    if args_opt.dataset_name == 'cifar10':
+        cfg = cifar_cfg
+    elif args_opt.dataset_name == 'imagenet':
+        cfg = imagenet_cfg
+    else:
+        raise ValueError("dataset is not supported.")
+
     net = GoogleNet(num_classes=cfg.num_classes)
+
+    assert cfg.checkpoint_path is not None, "cfg.checkpoint_path is None."
     param_dict = load_checkpoint(cfg.checkpoint_path)
     load_param_into_net(net, param_dict)
diff --git a/model_zoo/official/cv/googlenet/scripts/run_train.sh b/model_zoo/official/cv/googlenet/scripts/run_train.sh
index ed8a0e5f2a..1823e6ecc2 100644
--- a/model_zoo/official/cv/googlenet/scripts/run_train.sh
+++ b/model_zoo/official/cv/googlenet/scripts/run_train.sh
@@ -14,9 +14,9 @@
 # limitations under the License.
 # ============================================================================
 
-if [ $# != 1 ]
+if [ $# != 1 ] && [ $# != 2 ]
 then
-    echo "Usage: sh run_train.sh [RANK_TABLE_FILE]"
+    echo "Usage: sh run_train.sh [RANK_TABLE_FILE] [cifar10|imagenet]"
     exit 1
 fi
 
@@ -26,6 +26,19 @@ then
     exit 1
 fi
 
+
+dataset_type='cifar10'
+if [ $# == 2 ]
+then
+    if [ $2 != "cifar10" ] && [ $2 != "imagenet" ]
+    then
+        echo "error: the selected dataset is neither cifar10 nor imagenet"
+        exit 1
+    fi
+    dataset_type=$2
+fi
+
+
 ulimit -u unlimited
 export DEVICE_NUM=8
 export RANK_SIZE=8
@@ -43,9 +56,9 @@ do
     mkdir ./train_parallel$i
     cp -r ./src ./train_parallel$i
     cp ./train.py ./train_parallel$i
-    echo "start training for rank $RANK_ID, device $DEVICE_ID"
+    echo "start training for rank $RANK_ID, device $DEVICE_ID, $dataset_type"
    cd ./train_parallel$i ||exit
     env > env.log
-    python train.py --device_id=$i > log 2>&1 &
+    python train.py --device_id=$i --dataset_name=$dataset_type > log 2>&1 &
    cd ..
done
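The three entry points (eval.py, export.py, train.py) each repeat the same --dataset_name dispatch onto a config edict. The pattern reduces to the sketch below; select_cfg is a hypothetical helper shown only to make the plumbing explicit, the real edicts follow in src/config.py:

    from src.config import cifar_cfg, imagenet_cfg

    def select_cfg(dataset_name):
        """Map a --dataset_name value onto its config edict."""
        cfgs = {'cifar10': cifar_cfg, 'imagenet': imagenet_cfg}
        if dataset_name not in cfgs:
            raise ValueError("dataset '{}' is not supported.".format(dataset_name))
        return cfgs[dataset_name]

Hoisting such a helper into src/config.py would remove the three copies of the if/elif chain.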
diff --git a/model_zoo/official/cv/googlenet/src/config.py b/model_zoo/official/cv/googlenet/src/config.py
index 3c37afab34..2989e2ba69 100644
--- a/model_zoo/official/cv/googlenet/src/config.py
+++ b/model_zoo/official/cv/googlenet/src/config.py
@@ -18,6 +18,7 @@ network config setting, will be used in main.py
 from easydict import EasyDict as edict
 
 cifar_cfg = edict({
+    'name': 'cifar10',
     'pre_trained': False,
     'num_classes': 10,
     'lr_init': 0.1,
@@ -30,9 +31,45 @@ cifar_cfg = edict({
     'image_width': 224,
     'data_path': './cifar10',
     'device_target': 'Ascend',
-    'device_id': 4,
+    'device_id': 0,
     'keep_checkpoint_max': 10,
     'checkpoint_path': './train_googlenet_cifar10-125_390.ckpt',
     'onnx_filename': 'googlenet.onnx',
     'air_filename': 'googlenet.air'
 })
+
+imagenet_cfg = edict({
+    'name': 'imagenet',
+    'pre_trained': False,
+    'num_classes': 1000,
+    'lr_init': 0.1,
+    'batch_size': 256,
+    'epoch_size': 300,
+    'momentum': 0.9,
+    'weight_decay': 1e-4,
+    'buffer_size': None,  # unused parameter
+    'image_height': 224,
+    'image_width': 224,
+    'data_path': './ImageNet_Original/train/',
+    'val_data_path': './ImageNet_Original/val/',
+    'device_target': 'Ascend',
+    'device_id': 0,
+    'keep_checkpoint_max': 10,
+    'checkpoint_path': None,
+    'onnx_filename': 'googlenet.onnx',
+    'air_filename': 'googlenet.air',
+
+    # optimizer and lr related
+    'lr_scheduler': 'exponential',
+    'lr_epochs': [70, 140, 210, 280],
+    'lr_gamma': 0.3,
+    'eta_min': 0.0,
+    'T_max': 150,
+    'warmup_epochs': 0,
+
+    # loss related
+    'is_dynamic_loss_scale': 0,
+    'loss_scale': 1024,
+    'label_smooth_factor': 0.1,
+    'use_label_smooth': True,
+})
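With the 'exponential' defaults above, the rate starts at lr_init and is multiplied by lr_gamma at each milestone in lr_epochs (via the warmup_step_lr helper added later in this change). A quick check of the resulting plateaus:

    lr, gamma = 0.1, 0.3                   # imagenet_cfg.lr_init, imagenet_cfg.lr_gamma
    for milestone in [70, 140, 210, 280]:  # imagenet_cfg.lr_epochs
        lr *= gamma
        print("from epoch {:3d}: lr = {:.5f}".format(milestone, lr))
    # from epoch  70: lr = 0.03000
    # from epoch 140: lr = 0.00900
    # from epoch 210: lr = 0.00270
    # from epoch 280: lr = 0.00081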
diff --git a/model_zoo/official/cv/googlenet/src/dataset.py b/model_zoo/official/cv/googlenet/src/dataset.py
index 771624c585..7c1553a15a 100644
--- a/model_zoo/official/cv/googlenet/src/dataset.py
+++ b/model_zoo/official/cv/googlenet/src/dataset.py
@@ -21,10 +21,10 @@ import mindspore.common.dtype as mstype
 import mindspore.dataset as ds
 import mindspore.dataset.transforms.c_transforms as C
 import mindspore.dataset.vision.c_transforms as vision
-from src.config import cifar_cfg as cfg
+from src.config import cifar_cfg, imagenet_cfg
 
 
-def create_dataset(data_home, repeat_num=1, training=True):
+def create_dataset_cifar10(data_home, repeat_num=1, training=True):
     """Data operations."""
     ds.config.set_seed(1)
     data_dir = os.path.join(data_home, "cifar-10-batches-bin")
@@ -37,14 +37,14 @@ def create_dataset(data_home, repeat_num=1, training=True):
     else:
         data_set = ds.Cifar10Dataset(data_dir, num_shards=rank_size, shard_id=rank_id, shuffle=False)
 
-    resize_height = cfg.image_height
-    resize_width = cfg.image_width
+    resize_height = cifar_cfg.image_height
+    resize_width = cifar_cfg.image_width
 
     # define map operations
     random_crop_op = vision.RandomCrop((32, 32), (4, 4, 4, 4))  # padding_mode default CONSTANT
     random_horizontal_op = vision.RandomHorizontalFlip()
     resize_op = vision.Resize((resize_height, resize_width))  # interpolation default BILINEAR
-    rescale_op = vision.Rescale(1.0/255.0, 0.0)
+    rescale_op = vision.Rescale(1.0 / 255.0, 0.0)
     normalize_op = vision.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010))
     changeswap_op = vision.HWC2CHW()
     type_cast_op = C.TypeCast(mstype.int32)
@@ -59,7 +59,7 @@ def create_dataset(data_home, repeat_num=1, training=True):
     data_set = data_set.map(input_columns="image", operations=c_trans)
 
     # apply batch operations
-    data_set = data_set.batch(batch_size=cfg.batch_size, drop_remainder=True)
+    data_set = data_set.batch(batch_size=cifar_cfg.batch_size, drop_remainder=True)
 
     # apply repeat operations
     data_set = data_set.repeat(repeat_num)
@@ -67,6 +67,67 @@ def create_dataset(data_home, repeat_num=1, training=True):
     return data_set
 
 
+def create_dataset_imagenet(dataset_path, repeat_num=1, training=True,
+                            num_parallel_workers=None, shuffle=None):
+    """
+    Create a train or eval ImageNet2012 dataset for GoogleNet.
+
+    Args:
+        dataset_path(string): the path of dataset.
+        repeat_num(int): the repeat times of dataset. Default: 1
+        training(bool): whether dataset is used for train or eval. Default: True
+        num_parallel_workers(int): number of workers to read the data. Default: None
+        shuffle(bool): whether dataset is shuffled. Default: None
+
+    Returns:
+        dataset
+    """
+
+    device_num, rank_id = _get_rank_info()
+
+    if device_num == 1:
+        data_set = ds.ImageFolderDatasetV2(dataset_path, num_parallel_workers=num_parallel_workers, shuffle=shuffle)
+    else:
+        data_set = ds.ImageFolderDatasetV2(dataset_path, num_parallel_workers=num_parallel_workers, shuffle=shuffle,
+                                           num_shards=device_num, shard_id=rank_id)
+
+    assert imagenet_cfg.image_height == imagenet_cfg.image_width, "image_height must equal image_width"
+    image_size = imagenet_cfg.image_height
+    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
+    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]
+
+    # define map operations
+    if training:
+        transform_img = [
+            vision.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
+            vision.RandomHorizontalFlip(prob=0.5),
+            vision.RandomColorAdjust(0.4, 0.4, 0.4, 0.1),
+            vision.Normalize(mean=mean, std=std),
+            vision.HWC2CHW()
+        ]
+    else:
+        transform_img = [
+            vision.Decode(),
+            vision.Resize(256),
+            vision.CenterCrop(image_size),
+            vision.Normalize(mean=mean, std=std),
+            vision.HWC2CHW()
+        ]
+
+    transform_label = [C.TypeCast(mstype.int32)]
+
+    data_set = data_set.map(input_columns="image", num_parallel_workers=8, operations=transform_img)
+    data_set = data_set.map(input_columns="label", num_parallel_workers=8, operations=transform_label)
+
+    # apply batch operations
+    data_set = data_set.batch(imagenet_cfg.batch_size, drop_remainder=True)
+
+    # apply dataset repeat operation
+    data_set = data_set.repeat(repeat_num)
+
+    return data_set
+
+
 def _get_rank_info():
     """
     get rank size and rank id
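One detail worth calling out in create_dataset_imagenet: unlike the CIFAR-10 pipeline, there is no Rescale(1 / 255) step, so Normalize runs directly on 0..255 pixel values, and the familiar ImageNet channel statistics are therefore scaled by 255. The two formulations are equivalent, as a short NumPy check shows:

    import numpy as np

    pixel = np.array([0.0, 124.0, 255.0])   # raw uint8-range values
    mean, std = 0.485, 0.229                # red-channel ImageNet statistics
    a = (pixel - mean * 255) / (std * 255)  # normalize in the 0..255 domain
    b = (pixel / 255.0 - mean) / std        # rescale to 0..1 first, then normalize
    assert np.allclose(a, b)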
diff --git a/model_zoo/official/cv/googlenet/src/googlenet.py b/model_zoo/official/cv/googlenet/src/googlenet.py
index 69b26b0daf..78695f2d6c 100644
--- a/model_zoo/official/cv/googlenet/src/googlenet.py
+++ b/model_zoo/official/cv/googlenet/src/googlenet.py
@@ -112,6 +112,7 @@ class GoogleNet(nn.Cell):
 
     def construct(self, x):
+        """construct"""
         x = self.conv1(x)
         x = self.maxpool1(x)
 
diff --git a/model_zoo/official/cv/googlenet/src/lr_scheduler/__init__.py b/model_zoo/official/cv/googlenet/src/lr_scheduler/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/model_zoo/official/cv/googlenet/src/lr_scheduler/linear_warmup.py b/model_zoo/official/cv/googlenet/src/lr_scheduler/linear_warmup.py
new file mode 100644
index 0000000000..78e7b85f6d
--- /dev/null
+++ b/model_zoo/official/cv/googlenet/src/lr_scheduler/linear_warmup.py
@@ -0,0 +1,21 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""lr"""
+
+def linear_warmup_lr(current_step, warmup_steps, base_lr, init_lr):
+    """Interpolate linearly from init_lr to base_lr over warmup_steps."""
+    lr_inc = (float(base_lr) - float(init_lr)) / float(warmup_steps)
+    lr = float(init_lr) + lr_inc * current_step
+    return lr
diff --git a/model_zoo/official/cv/googlenet/src/lr_scheduler/warmup_cosine_annealing_lr.py b/model_zoo/official/cv/googlenet/src/lr_scheduler/warmup_cosine_annealing_lr.py
new file mode 100644
index 0000000000..349270b6e1
--- /dev/null
+++ b/model_zoo/official/cv/googlenet/src/lr_scheduler/warmup_cosine_annealing_lr.py
@@ -0,0 +1,39 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""lr"""
+
+import math
+import numpy as np
+
+from .linear_warmup import linear_warmup_lr
+
+
+def warmup_cosine_annealing_lr(lr, steps_per_epoch, warmup_epochs, max_epoch, T_max, eta_min=0):
+    """warmup cosine annealing lr"""
+    base_lr = lr
+    warmup_init_lr = 0
+    total_steps = int(max_epoch * steps_per_epoch)
+    warmup_steps = int(warmup_epochs * steps_per_epoch)
+
+    lr_each_step = []
+    for i in range(total_steps):
+        last_epoch = i // steps_per_epoch
+        if i < warmup_steps:
+            lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
+        else:
+            lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi * last_epoch / T_max)) / 2
+        lr_each_step.append(lr)
+
+    return np.array(lr_each_step).astype(np.float32)
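Plugging the imagenet_cfg defaults (lr_init=0.1, eta_min=0.0, T_max=150) into the annealing formula above shows the shape of the curve. Note that with T_max at half of epoch_size=300, the cosine passes its trough at epoch 150 and climbs back toward lr_init afterwards; setting T_max equal to epoch_size would instead give a single half-cosine decay:

    import math

    base_lr, eta_min, T_max = 0.1, 0.0, 150
    for epoch in (0, 75, 150, 225, 299):
        lr = eta_min + (base_lr - eta_min) * (1. + math.cos(math.pi * epoch / T_max)) / 2
        print(epoch, round(lr, 5))
    # 0 0.1
    # 75 0.05
    # 150 0.0
    # 225 0.05
    # 299 0.09999  (back near the starting rate)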
diff --git a/model_zoo/official/cv/googlenet/src/lr_scheduler/warmup_step_lr.py b/model_zoo/official/cv/googlenet/src/lr_scheduler/warmup_step_lr.py
new file mode 100644
index 0000000000..df78f17d79
--- /dev/null
+++ b/model_zoo/official/cv/googlenet/src/lr_scheduler/warmup_step_lr.py
@@ -0,0 +1,59 @@
+# Copyright 2020 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+"""lr"""
+
+from collections import Counter
+import numpy as np
+
+from .linear_warmup import linear_warmup_lr
+
+
+def warmup_step_lr(lr, lr_epochs, steps_per_epoch, warmup_epochs, max_epoch, gamma=0.1):
+    """warmup step lr"""
+    base_lr = lr
+    warmup_init_lr = 0
+    total_steps = int(max_epoch * steps_per_epoch)
+    warmup_steps = int(warmup_epochs * steps_per_epoch)
+    milestones = lr_epochs
+    milestones_steps = []
+    for milestone in milestones:
+        milestones_step = milestone * steps_per_epoch
+        milestones_steps.append(milestones_step)
+
+    lr_each_step = []
+    lr = base_lr
+    milestones_steps_counter = Counter(milestones_steps)
+    for i in range(total_steps):
+        if i < warmup_steps:
+            lr = linear_warmup_lr(i + 1, warmup_steps, base_lr, warmup_init_lr)
+        else:
+            lr = lr * gamma ** milestones_steps_counter[i]
+        lr_each_step.append(lr)
+
+    return np.array(lr_each_step).astype(np.float32)
+
+
+def multi_step_lr(lr, milestones, steps_per_epoch, max_epoch, gamma=0.1):
+    """Multi-step decay: gamma is applied at each milestone epoch."""
+    return warmup_step_lr(lr, milestones, steps_per_epoch, 0, max_epoch, gamma=gamma)
+
+
+def step_lr(lr, epoch_size, steps_per_epoch, max_epoch, gamma=0.1):
+    """Step decay: gamma is applied every epoch_size epochs."""
+    lr_epochs = []
+    for i in range(1, max_epoch):
+        if i % epoch_size == 0:
+            lr_epochs.append(i)
+    return multi_step_lr(lr, lr_epochs, steps_per_epoch, max_epoch, gamma=gamma)
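Both schedules can be inspected offline before committing to a run. A usage sketch; the 5004 steps/epoch figure is an assumption (ImageNet-1k train split of about 1,281,167 images at batch_size 256):

    from src.lr_scheduler.warmup_step_lr import warmup_step_lr
    from src.lr_scheduler.warmup_cosine_annealing_lr import warmup_cosine_annealing_lr

    steps_per_epoch = 5004  # assumed: 1281167 // 256
    exp_lr = warmup_step_lr(0.1, [70, 140, 210, 280], steps_per_epoch, 0, 300, gamma=0.3)
    cos_lr = warmup_cosine_annealing_lr(0.1, steps_per_epoch, 0, 300, 150, eta_min=0.0)

    print(exp_lr.shape)                             # (1501200,), one entry per global step
    print(exp_lr[0], exp_lr[70 * steps_per_epoch])  # 0.1, then 0.03 at the first milestone
    print(cos_lr[150 * steps_per_epoch])            # ~0.0 at the T_max trough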
diff --git a/model_zoo/official/cv/googlenet/train.py b/model_zoo/official/cv/googlenet/train.py
index 53de6ef9c4..b370177530 100644
--- a/model_zoo/official/cv/googlenet/train.py
+++ b/model_zoo/official/cv/googlenet/train.py
@@ -27,18 +27,19 @@ from mindspore import context
 from mindspore.communication.management import init, get_rank
 from mindspore.nn.optim.momentum import Momentum
 from mindspore.train.callback import ModelCheckpoint, CheckpointConfig, LossMonitor, TimeMonitor
+from mindspore.train.loss_scale_manager import DynamicLossScaleManager, FixedLossScaleManager
 from mindspore.train.model import Model
 from mindspore.context import ParallelMode
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from mindspore.common import set_seed
 
-from src.config import cifar_cfg as cfg
-from src.dataset import create_dataset
+from src.config import cifar_cfg, imagenet_cfg
+from src.dataset import create_dataset_cifar10, create_dataset_imagenet
 from src.googlenet import GoogleNet
 
 set_seed(1)
 
-def lr_steps(global_step, lr_max=None, total_epochs=None, steps_per_epoch=None):
+def lr_steps_cifar10(global_step, lr_max=None, total_epochs=None, steps_per_epoch=None):
     """Set learning rate."""
     lr_each_step = []
     total_steps = steps_per_epoch * total_epochs
@@ -59,11 +60,46 @@ def lr_steps(global_step, lr_max=None, total_epochs=None, steps_per_epoch=None)
     return learning_rate
 
 
+def lr_steps_imagenet(_cfg, steps_per_epoch):
+    """lr step for imagenet"""
+    from src.lr_scheduler.warmup_step_lr import warmup_step_lr
+    from src.lr_scheduler.warmup_cosine_annealing_lr import warmup_cosine_annealing_lr
+    if _cfg.lr_scheduler == 'exponential':
+        _lr = warmup_step_lr(_cfg.lr_init,
+                             _cfg.lr_epochs,
+                             steps_per_epoch,
+                             _cfg.warmup_epochs,
+                             _cfg.epoch_size,
+                             gamma=_cfg.lr_gamma,
+                             )
+    elif _cfg.lr_scheduler == 'cosine_annealing':
+        _lr = warmup_cosine_annealing_lr(_cfg.lr_init,
+                                         steps_per_epoch,
+                                         _cfg.warmup_epochs,
+                                         _cfg.epoch_size,
+                                         _cfg.T_max,
+                                         _cfg.eta_min)
+    else:
+        raise NotImplementedError(_cfg.lr_scheduler)
+
+    return _lr
+
+
 if __name__ == '__main__':
-    parser = argparse.ArgumentParser(description='Cifar10 classification')
+    parser = argparse.ArgumentParser(description='Classification')
+    parser.add_argument('--dataset_name', type=str, default='cifar10', choices=['imagenet', 'cifar10'],
+                        help='dataset name.')
     parser.add_argument('--device_id', type=int, default=None, help='device id of GPU or Ascend. (Default: None)')
     args_opt = parser.parse_args()
 
+    if args_opt.dataset_name == "cifar10":
+        cfg = cifar_cfg
+    elif args_opt.dataset_name == "imagenet":
+        cfg = imagenet_cfg
+    else:
+        raise ValueError("Unsupported dataset.")
+
+    # set context
     device_target = cfg.device_target
 
     context.set_context(mode=context.GRAPH_MODE, device_target=cfg.device_target)
@@ -90,7 +126,13 @@ if __name__ == '__main__':
     else:
         raise ValueError("Unsupported platform.")
 
-    dataset = create_dataset(cfg.data_path, 1)
+    if args_opt.dataset_name == "cifar10":
+        dataset = create_dataset_cifar10(cfg.data_path, 1)
+    elif args_opt.dataset_name == "imagenet":
+        dataset = create_dataset_imagenet(cfg.data_path, 1)
+    else:
+        raise ValueError("Unsupported dataset.")
+
     batch_num = dataset.get_dataset_size()
 
     net = GoogleNet(num_classes=cfg.num_classes)
@@ -98,23 +140,75 @@ if __name__ == '__main__':
     if cfg.pre_trained:
         param_dict = load_checkpoint(cfg.checkpoint_path)
         load_param_into_net(net, param_dict)
-    lr = lr_steps(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size, steps_per_epoch=batch_num)
-    opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), Tensor(lr), cfg.momentum,
-                   weight_decay=cfg.weight_decay)
-    loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
+
+    loss_scale_manager = None
+    if args_opt.dataset_name == 'cifar10':
+        lr = lr_steps_cifar10(0, lr_max=cfg.lr_init, total_epochs=cfg.epoch_size, steps_per_epoch=batch_num)
+        opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()),
+                       learning_rate=Tensor(lr),
+                       momentum=cfg.momentum,
+                       weight_decay=cfg.weight_decay)
+        loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean', is_grad=False)
+
+    elif args_opt.dataset_name == 'imagenet':
+        lr = lr_steps_imagenet(cfg, batch_num)
+
+
+        def get_param_groups(network):
+            """Split trainable params into decay and no-decay groups."""
+            decay_params = []
+            no_decay_params = []
+            for x in network.trainable_params():
+                parameter_name = x.name
+                if parameter_name.endswith('.bias'):
+                    # biases do not use weight decay
+                    no_decay_params.append(x)
+                elif parameter_name.endswith('.gamma'):
+                    # BN gamma does not use weight decay; note this network may not contain BN
+                    no_decay_params.append(x)
+                elif parameter_name.endswith('.beta'):
+                    # BN beta does not use weight decay; note this network may not contain BN
+                    no_decay_params.append(x)
+                else:
+                    decay_params.append(x)
+
+            return [{'params': no_decay_params, 'weight_decay': 0.0}, {'params': decay_params}]
+
+
+        if cfg.is_dynamic_loss_scale:
+            cfg.loss_scale = 1
+
+        opt = Momentum(params=get_param_groups(net),
+                       learning_rate=Tensor(lr),
+                       momentum=cfg.momentum,
+                       weight_decay=cfg.weight_decay,
+                       loss_scale=cfg.loss_scale)
+        if not cfg.use_label_smooth:
+            cfg.label_smooth_factor = 0.0
+        loss = nn.SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean",
+                                                smooth_factor=cfg.label_smooth_factor, num_classes=cfg.num_classes)
+
+        if cfg.is_dynamic_loss_scale == 1:
+            loss_scale_manager = DynamicLossScaleManager(init_loss_scale=65536, scale_factor=2, scale_window=2000)
+        else:
+            loss_scale_manager = FixedLossScaleManager(cfg.loss_scale, drop_overflow_update=False)
 
     if device_target == "Ascend":
         model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},
-                      amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=None)
+                      amp_level="O2", keep_batchnorm_fp32=False, loss_scale_manager=loss_scale_manager)
         ckpt_save_dir = "./"
-    else: # GPU
+    else:  # GPU
         model = Model(net, loss_fn=loss, optimizer=opt, metrics={'acc'},
-                      amp_level="O2", keep_batchnorm_fp32=True, loss_scale_manager=None)
+                      amp_level="O2", keep_batchnorm_fp32=True, loss_scale_manager=loss_scale_manager)
         ckpt_save_dir = "./ckpt_" + str(get_rank()) + "/"
 
     config_ck = CheckpointConfig(save_checkpoint_steps=batch_num * 5, keep_checkpoint_max=cfg.keep_checkpoint_max)
     time_cb = TimeMonitor(data_size=batch_num)
-    ckpoint_cb = ModelCheckpoint(prefix="train_googlenet_cifar10", directory=ckpt_save_dir, config=config_ck)
+    ckpoint_cb = ModelCheckpoint(prefix="train_googlenet_" + args_opt.dataset_name, directory=ckpt_save_dir,
+                                 config=config_ck)
     loss_cb = LossMonitor()
     model.train(cfg.epoch_size, dataset, callbacks=[time_cb, ckpoint_cb, loss_cb])
     print("train success")
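The weight-decay split in get_param_groups above reduces to a name-suffix rule; a standalone restatement (the parameter names below are hypothetical, for illustration only):

    def uses_weight_decay(parameter_name):
        """Mirror of get_param_groups(): no decay for bias/BN-affine parameters."""
        return not parameter_name.endswith(('.bias', '.gamma', '.beta'))

    names = ['conv1.0.weight', 'conv1.0.bias', 'bn1.gamma', 'bn1.beta', 'classifier.weight']
    print([n for n in names if uses_weight_decay(n)])
    # ['conv1.0.weight', 'classifier.weight']

Passing the two resulting groups to Momentum applies each group's own 'weight_decay', so only the decay_params group feels cfg.weight_decay.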