diff --git a/model_zoo/official/cv/resnet/eval.py b/model_zoo/official/cv/resnet/eval.py index 426b8c9f3de..e0925346c78 100755 --- a/model_zoo/official/cv/resnet/eval.py +++ b/model_zoo/official/cv/resnet/eval.py @@ -22,7 +22,6 @@ from mindspore import dataset as de from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits from mindspore.train.model import Model from mindspore.train.serialization import load_checkpoint, load_param_into_net -from src.crossentropy import CrossEntropy parser = argparse.ArgumentParser(description='Image classification') parser.add_argument('--net', type=str, default=None, help='Resnet Model, either resnet50 or resnet101') @@ -78,7 +77,8 @@ if __name__ == '__main__': if args_opt.dataset == "imagenet2012": if not config.use_label_smooth: config.label_smooth_factor = 0.0 - loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num) + loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean", + smooth_factor=config.label_smooth_factor, num_classes=config.class_num) else: loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') diff --git a/model_zoo/official/cv/resnet/src/crossentropy.py b/model_zoo/official/cv/resnet/src/crossentropy.py deleted file mode 100755 index 5118cb51612..00000000000 --- a/model_zoo/official/cv/resnet/src/crossentropy.py +++ /dev/null @@ -1,39 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
-# ============================================================================ -"""define loss function for network""" -from mindspore.nn.loss.loss import _Loss -from mindspore.ops import operations as P -from mindspore.ops import functional as F -from mindspore import Tensor -from mindspore.common import dtype as mstype -import mindspore.nn as nn - - -class CrossEntropy(_Loss): - """the redefined loss function with SoftmaxCrossEntropyWithLogits""" - - def __init__(self, smooth_factor=0., num_classes=1001): - super(CrossEntropy, self).__init__() - self.onehot = P.OneHot() - self.on_value = Tensor(1.0 - smooth_factor, mstype.float32) - self.off_value = Tensor(1.0 * smooth_factor / (num_classes - 1), mstype.float32) - self.ce = nn.SoftmaxCrossEntropyWithLogits() - self.mean = P.ReduceMean(False) - - def construct(self, logit, label): - one_hot_label = self.onehot(label, F.shape(logit)[1], self.on_value, self.off_value) - loss = self.ce(logit, one_hot_label) - loss = self.mean(loss, 0) - return loss diff --git a/model_zoo/official/cv/resnet/src/dataset.py b/model_zoo/official/cv/resnet/src/dataset.py index ac0adc4bc97..1cbe5e20f3b 100755 --- a/model_zoo/official/cv/resnet/src/dataset.py +++ b/model_zoo/official/cv/resnet/src/dataset.py @@ -121,7 +121,7 @@ def create_dataset2(dataset_path, do_train, repeat_num=1, batch_size=32, target= else: trans = [ C.Decode(), - C.Resize((256, 256)), + C.Resize(256), C.CenterCrop(image_size), C.Normalize(mean=mean, std=std), C.HWC2CHW() diff --git a/model_zoo/official/cv/resnet/train.py b/model_zoo/official/cv/resnet/train.py index 7344ff08759..5f4a1423837 100755 --- a/model_zoo/official/cv/resnet/train.py +++ b/model_zoo/official/cv/resnet/train.py @@ -31,7 +31,6 @@ from mindspore.communication.management import init, get_rank, get_group_size import mindspore.nn as nn import mindspore.common.initializer as weight_init from src.lr_generator import get_lr, warmup_cosine_annealing_lr -from src.crossentropy import CrossEntropy parser 
= argparse.ArgumentParser(description='Image classification') parser.add_argument('--net', type=str, default=None, help='Resnet Model, either resnet50 or resnet101') @@ -75,7 +74,7 @@ if __name__ == '__main__': context.set_auto_parallel_context(device_num=args_opt.device_num, parallel_mode=ParallelMode.DATA_PARALLEL, mirror_mean=True) if args_opt.net == "resnet50": - auto_parallel_context().set_all_reduce_fusion_split_indices([107, 160]) + auto_parallel_context().set_all_reduce_fusion_split_indices([85, 160]) else: auto_parallel_context().set_all_reduce_fusion_split_indices([180, 313]) init() @@ -128,15 +127,19 @@ if __name__ == '__main__': lr = Tensor(lr) # define opt - opt = Momentum(filter(lambda x: x.requires_grad, net.get_parameters()), lr, config.momentum, - config.weight_decay, config.loss_scale) - + decayed_params = list(filter(lambda x: 'beta' not in x.name and 'gamma' not in x.name and 'bias' not in x.name, net.trainable_params())) + no_decayed_params = [param for param in net.trainable_params() if param not in decayed_params] + group_params = [{'params': decayed_params, 'weight_decay': config.weight_decay}, + {'params': no_decayed_params}, + {'order_params': net.trainable_params()}] + opt = Momentum(group_params, lr, config.momentum, loss_scale=config.loss_scale) # define loss, model if target == "Ascend": if args_opt.dataset == "imagenet2012": if not config.use_label_smooth: config.label_smooth_factor = 0.0 - loss = CrossEntropy(smooth_factor=config.label_smooth_factor, num_classes=config.class_num) + loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction="mean", + smooth_factor=config.label_smooth_factor, num_classes=config.class_num) else: loss = SoftmaxCrossEntropyWithLogits(sparse=True, reduction='mean') loss_scale = FixedLossScaleManager(config.loss_scale, drop_overflow_update=False)