forked from mindspore-Ecosystem/mindspore
!3456 modify resnet50 on imagenet to improve the performance and accuracy
Merge pull request !3456 from guoqi/master
This commit is contained in:
commit
7fbed0ce94
|
@ -22,7 +22,6 @@ from mindspore import dataset as de
|
||||||
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
|
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
|
||||||
from mindspore.train.model import Model
|
from mindspore.train.model import Model
|
||||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||||
from src.crossentropy import CrossEntropy
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='Image classification')
|
parser = argparse.ArgumentParser(description='Image classification')
|
||||||
parser.add_argument('--net', type=str, default=None, help='Resnet Model, either resnet50 or resnet101')
|
parser.add_argument('--net', type=str, default=None, help='Resnet Model, either resnet50 or resnet101')
|
||||||
|
@ -78,7 +77,8 @@ if __name__ == '__main__':
|
||||||
if args_opt.dataset == "imagenet2012":
|
if args_opt.dataset == "imagenet2012":
|
||||||
if not config.use_label_smooth:
|
if not config.use_label_smooth:
|
||||||
config.label_smooth_factor = 0.0
|
config.label_smooth_factor = 0.0
|
||||||
loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
|
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean",
|
||||||
|
smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
|
||||||
else:
|
else:
|
||||||
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
|
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
|
||||||
|
|
||||||
|
|
|
@ -1,39 +0,0 @@
|
||||||
# Copyright 2020 Huawei Technologies Co., Ltd
|
|
||||||
#
|
|
||||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
# you may not use this file except in compliance with the License.
|
|
||||||
# You may obtain a copy of the License at
|
|
||||||
#
|
|
||||||
# http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
#
|
|
||||||
# Unless required by applicable law or agreed to in writing, software
|
|
||||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
# See the License for the specific language governing permissions and
|
|
||||||
# limitations under the License.
|
|
||||||
# ============================================================================
|
|
||||||
"""define loss function for network"""
|
|
||||||
from mindspore.nn.loss.loss import _Loss
|
|
||||||
from mindspore.ops import operations as P
|
|
||||||
from mindspore.ops import functional as F
|
|
||||||
from mindspore import Tensor
|
|
||||||
from mindspore.common import dtype as mstype
|
|
||||||
import mindspore.nn as nn
|
|
||||||
|
|
||||||
|
|
||||||
class CrossEntropy(_Loss):
|
|
||||||
"""the redefined loss function with SoftmaxCrossEntropyWithLogits"""
|
|
||||||
|
|
||||||
def __init__(self, smooth_factor=0., num_classes=1001):
|
|
||||||
super(CrossEntropy, self).__init__()
|
|
||||||
self.onehot = P.OneHot()
|
|
||||||
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
|
|
||||||
self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
|
|
||||||
self.ce = nn.SoftmaxCrossEntropyWithLogits()
|
|
||||||
self.mean = P.ReduceMean(False)
|
|
||||||
|
|
||||||
def construct(self, logit, label):
|
|
||||||
one_hot_label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
|
|
||||||
loss = self.ce(logit, one_hot_label)
|
|
||||||
loss = self.mean(loss, 0)
|
|
||||||
return loss
|
|
|
@ -121,7 +121,7 @@ def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target=
|
||||||
else:
|
else:
|
||||||
trans = [
|
trans = [
|
||||||
C.Decode(),
|
C.Decode(),
|
||||||
C.Resize((256, 256)),
|
C.Resize(256),
|
||||||
C.CenterCrop(image_size),
|
C.CenterCrop(image_size),
|
||||||
C.Normalize(mean=mean, std=std),
|
C.Normalize(mean=mean, std=std),
|
||||||
C.HWC2CHW()
|
C.HWC2CHW()
|
||||||
|
|
|
@ -31,7 +31,6 @@ from mindspore.communication.management import init, get_rank, get_group_size
|
||||||
import mindspore.nn as nn
|
import mindspore.nn as nn
|
||||||
import mindspore.common.initializer as weight_init
|
import mindspore.common.initializer as weight_init
|
||||||
from src.lr_generator import get_lr, warmup_cosine_annealing_lr
|
from src.lr_generator import get_lr, warmup_cosine_annealing_lr
|
||||||
from src.crossentropy import CrossEntropy
|
|
||||||
|
|
||||||
parser = argparse.ArgumentParser(description='Image classification')
|
parser = argparse.ArgumentParser(description='Image classification')
|
||||||
parser.add_argument('--net', type=str, default=None, help='Resnet Model, either resnet50 or resnet101')
|
parser.add_argument('--net', type=str, default=None, help='Resnet Model, either resnet50 or resnet101')
|
||||||
|
@ -75,7 +74,7 @@ if __name__ == '__main__':
|
||||||
context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
|
context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||||
mirror_mean=True)
|
mirror_mean=True)
|
||||||
if args_opt.net == "resnet50":
|
if args_opt.net == "resnet50":
|
||||||
auto_parallel_context().set_all_reduce_fusion_split_indices([107, 160])
|
auto_parallel_context().set_all_reduce_fusion_split_indices([85, 160])
|
||||||
else:
|
else:
|
||||||
auto_parallel_context().set_all_reduce_fusion_split_indices([180, 313])
|
auto_parallel_context().set_all_reduce_fusion_split_indices([180, 313])
|
||||||
init()
|
init()
|
||||||
|
@ -128,15 +127,19 @@ if __name__ == '__main__':
|
||||||
lr = Tensor(lr)
|
lr = Tensor(lr)
|
||||||
|
|
||||||
# define opt
|
# define opt
|
||||||
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum,
|
decayed_params = list(filter(lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'bias' not in x.name, net.trainalbe_params()))
|
||||||
config.weight_decay, config.loss_scale)
|
no_decayed_params = [param for param in net.trainalbe_params() if param not in decayed_params]
|
||||||
|
group_params = [{'params': decayed_params, 'weight_decay': config.weight_decay},
|
||||||
|
{'params': no_decayed_params},
|
||||||
|
{'order_params': net.trainalbe_params()}]
|
||||||
|
opt = Momentum(group_params, lr, config.momentum, loss_scale=config.loss_scale)
|
||||||
# define loss, model
|
# define loss, model
|
||||||
if target == "Ascend":
|
if target == "Ascend":
|
||||||
if args_opt.dataset == "imagenet2012":
|
if args_opt.dataset == "imagenet2012":
|
||||||
if not config.use_label_smooth:
|
if not config.use_label_smooth:
|
||||||
config.label_smooth_factor = 0.0
|
config.label_smooth_factor = 0.0
|
||||||
loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
|
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean",
|
||||||
|
smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
|
||||||
else:
|
else:
|
||||||
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
|
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
|
||||||
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
|
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
|
||||||
|
|
Loading…
Reference in New Issue