modified resnet50 on imagenet2012 to improve the performance and accuracy

This commit is contained in:
guoqi 2020-07-25 09:19:06 +08:00
parent ea545dc52f
commit 63a4b77065
4 changed files with 12 additions and 48 deletions

View File

@ -22,7 +22,6 @@ from mindspore import dataset as de
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.train.model import Model from mindspore.train.model import Model
from mindspore.train.serialization import load_checkpoint, load_param_into_net from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.crossentropy import CrossEntropy
parser = argparse.ArgumentParser(description='Image classification') parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--net', type=str, default=None, help='Resnet Model, either resnet50 or resnet101') parser.add_argument('--net', type=str, default=None, help='Resnet Model, either resnet50 or resnet101')
@ -78,7 +77,8 @@ if __name__ == '__main__':
if args_opt.dataset == "imagenet2012": if args_opt.dataset == "imagenet2012":
if not config.use_label_smooth: if not config.use_label_smooth:
config.label_smooth_factor = 0.0 config.label_smooth_factor = 0.0
loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num) loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean",
smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
else: else:
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')

View File

@ -1,39 +0,0 @@
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""define loss function for network"""
from mindspore.nn.loss.loss import _Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore import Tensor
from mindspore.common import dtype as mstype
import mindspore.nn as nn
class CrossEntropy(_Loss):
"""the redefined loss function with SoftmaxCrossEntropyWithLogits"""
def __init__(self, smooth_factor=0., num_classes=1001):
super(CrossEntropy, self).__init__()
self.onehot = P.OneHot()
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
self.ce = nn.SoftmaxCrossEntropyWithLogits()
self.mean = P.ReduceMean(False)
def construct(self, logit, label):
one_hot_label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
loss = self.ce(logit, one_hot_label)
loss = self.mean(loss, 0)
return loss

View File

@ -121,7 +121,7 @@ def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target=
else: else:
trans = [ trans = [
C.Decode(), C.Decode(),
C.Resize((256, 256)), C.Resize(256),
C.CenterCrop(image_size), C.CenterCrop(image_size),
C.Normalize(mean=mean, std=std), C.Normalize(mean=mean, std=std),
C.HWC2CHW() C.HWC2CHW()

View File

@ -31,7 +31,6 @@ from mindspore.communication.management import init, get_rank, get_group_size
import mindspore.nn as nn import mindspore.nn as nn
import mindspore.common.initializer as weight_init import mindspore.common.initializer as weight_init
from src.lr_generator import get_lr, warmup_cosine_annealing_lr from src.lr_generator import get_lr, warmup_cosine_annealing_lr
from src.crossentropy import CrossEntropy
parser = argparse.ArgumentParser(description='Image classification') parser = argparse.ArgumentParser(description='Image classification')
parser.add_argument('--net', type=str, default=None, help='Resnet Model, either resnet50 or resnet101') parser.add_argument('--net', type=str, default=None, help='Resnet Model, either resnet50 or resnet101')
@ -75,7 +74,7 @@ if __name__ == '__main__':
context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL, context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
mirror_mean=True) mirror_mean=True)
if args_opt.net == "resnet50": if args_opt.net == "resnet50":
auto_parallel_context().set_all_reduce_fusion_split_indices([107, 160]) auto_parallel_context().set_all_reduce_fusion_split_indices([85, 160])
else: else:
auto_parallel_context().set_all_reduce_fusion_split_indices([180, 313]) auto_parallel_context().set_all_reduce_fusion_split_indices([180, 313])
init() init()
@ -128,15 +127,19 @@ if __name__ == '__main__':
lr = Tensor(lr) lr = Tensor(lr)
# define opt # define opt
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, decayed_params = list(filter(lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'bias' not in x.name, net.trainalbe_params()))
config.weight_decay, config.loss_scale) no_decayed_params = [param for param in net.trainalbe_params() if param not in decayed_params]
group_params = [{'params': decayed_params, 'weight_decay': config.weight_decay},
{'params': no_decayed_params},
{'order_params': net.trainalbe_params()}]
opt = Momentum(group_params, lr, config.momentum, loss_scale=config.loss_scale)
# define loss, model # define loss, model
if target == "Ascend": if target == "Ascend":
if args_opt.dataset == "imagenet2012": if args_opt.dataset == "imagenet2012":
if not config.use_label_smooth: if not config.use_label_smooth:
config.label_smooth_factor = 0.0 config.label_smooth_factor = 0.0
loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num) loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean",
smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
else: else:
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False) loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)