forked from OSSInnovation/mindspore
!3456 modify resnet50 on imagenet to improve the performance and accuracy
Merge pull request !3456 from guoqi/master
This commit is contained in:
commit
7fbed0ce94
|
@ -22,7 +22,6 @@ from mindspore import dataset as de
|
|||
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
|
||||
from mindspore.train.model import Model
|
||||
from mindspore.train.serialization import load_checkpoint, load_param_into_net
|
||||
from src.crossentropy import CrossEntropy
|
||||
|
||||
parser = argparse.ArgumentParser(description='Image classification')
|
||||
parser.add_argument('--net', type=str, default=None, help='Resnet Model, either resnet50 or resnet101')
|
||||
|
@ -78,7 +77,8 @@ if __name__ == '__main__':
|
|||
if args_opt.dataset == "imagenet2012":
|
||||
if not config.use_label_smooth:
|
||||
config.label_smooth_factor = 0.0
|
||||
loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
|
||||
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean",
|
||||
smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
|
||||
else:
|
||||
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
|
||||
|
||||
|
|
|
@ -1,39 +0,0 @@
|
|||
# Copyright 2020 Huawei Technologies Co., Ltd
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ============================================================================
|
||||
"""define loss function for network"""
|
||||
from mindspore.nn.loss.loss import _Loss
|
||||
from mindspore.ops import operations as P
|
||||
from mindspore.ops import functional as F
|
||||
from mindspore import Tensor
|
||||
from mindspore.common import dtype as mstype
|
||||
import mindspore.nn as nn
|
||||
|
||||
|
||||
class CrossEntropy(_Loss):
|
||||
"""the redefined loss function with SoftmaxCrossEntropyWithLogits"""
|
||||
|
||||
def __init__(self, smooth_factor=0., num_classes=1001):
|
||||
super(CrossEntropy, self).__init__()
|
||||
self.onehot = P.OneHot()
|
||||
self.on_value = Tensor(1.0 - smooth_factor, mstype.float32)
|
||||
self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32)
|
||||
self.ce = nn.SoftmaxCrossEntropyWithLogits()
|
||||
self.mean = P.ReduceMean(False)
|
||||
|
||||
def construct(self, logit, label):
|
||||
one_hot_label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value)
|
||||
loss = self.ce(logit, one_hot_label)
|
||||
loss = self.mean(loss, 0)
|
||||
return loss
|
|
@ -121,7 +121,7 @@ def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target=
|
|||
else:
|
||||
trans = [
|
||||
C.Decode(),
|
||||
C.Resize((256, 256)),
|
||||
C.Resize(256),
|
||||
C.CenterCrop(image_size),
|
||||
C.Normalize(mean=mean, std=std),
|
||||
C.HWC2CHW()
|
||||
|
|
|
@ -31,7 +31,6 @@ from mindspore.communication.management import init, get_rank, get_group_size
|
|||
import mindspore.nn as nn
|
||||
import mindspore.common.initializer as weight_init
|
||||
from src.lr_generator import get_lr, warmup_cosine_annealing_lr
|
||||
from src.crossentropy import CrossEntropy
|
||||
|
||||
parser = argparse.ArgumentParser(description='Image classification')
|
||||
parser.add_argument('--net', type=str, default=None, help='Resnet Model, either resnet50 or resnet101')
|
||||
|
@ -75,7 +74,7 @@ if __name__ == '__main__':
|
|||
context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL,
|
||||
mirror_mean=True)
|
||||
if args_opt.net == "resnet50":
|
||||
auto_parallel_context().set_all_reduce_fusion_split_indices([107, 160])
|
||||
auto_parallel_context().set_all_reduce_fusion_split_indices([85, 160])
|
||||
else:
|
||||
auto_parallel_context().set_all_reduce_fusion_split_indices([180, 313])
|
||||
init()
|
||||
|
@ -128,15 +127,19 @@ if __name__ == '__main__':
|
|||
lr = Tensor(lr)
|
||||
|
||||
# define opt
|
||||
opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum,
|
||||
config.weight_decay, config.loss_scale)
|
||||
|
||||
decayed_params = list(filter(lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'bias' not in x.name, net.trainalbe_params()))
|
||||
no_decayed_params = [param for param in net.trainalbe_params() if param not in decayed_params]
|
||||
group_params = [{'params': decayed_params, 'weight_decay': config.weight_decay},
|
||||
{'params': no_decayed_params},
|
||||
{'order_params': net.trainalbe_params()}]
|
||||
opt = Momentum(group_params, lr, config.momentum, loss_scale=config.loss_scale)
|
||||
# define loss, model
|
||||
if target == "Ascend":
|
||||
if args_opt.dataset == "imagenet2012":
|
||||
if not config.use_label_smooth:
|
||||
config.label_smooth_factor = 0.0
|
||||
loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
|
||||
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean",
|
||||
smooth_factor=config.label_smooth_factor, num_classes=config.class_num)
|
||||
else:
|
||||
loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean')
|
||||
loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)
|
||||
|
|
Loading…
Reference in New Issue