diff --git a/mindspore/nn/loss/loss.py b/mindspore/nn/loss/loss.py index 13a424f8e11..42d65f87652 100644 --- a/mindspore/nn/loss/loss.py +++ b/mindspore/nn/loss/loss.py @@ -29,7 +29,7 @@ class _Loss(Cell): """ Base class for other losses. """ - def __init__(self, reduction='mean'): + def __init__(self, reduction='mean', weights=1.0): super(_Loss, self).__init__() if reduction is None: reduction = 'none' @@ -46,6 +46,11 @@ class _Loss(Cell): self.reduce_mean = _selected_ops.ReduceMean() self.reduce_sum = P.ReduceSum() + self.mul = P.Mul() + if isinstance(weights, int): + self.weights = float(weights) + else: + self.weights = weights def get_axis(self, x): shape = F.shape(x) @@ -54,6 +59,8 @@ class _Loss(Cell): return perm def get_loss(self, x): + if self.weights != 1.0: + x = self.mul(self.weights, x) if self.reduce and self.average: x = self.reduce_mean(x, self.get_axis(x)) if self.reduce and not self.average: diff --git a/mindspore/nn/optim/rmsprop.py b/mindspore/nn/optim/rmsprop.py index 5d364cd9eb0..3b4d9110559 100644 --- a/mindspore/nn/optim/rmsprop.py +++ b/mindspore/nn/optim/rmsprop.py @@ -131,7 +131,7 @@ class RMSProp(Optimizer): Tensor[bool], the value is True. Supported Platforms: - ``Ascend`` ``GPU`` + ``Ascend`` ``GPU`` ``CPU`` Examples: >>> net = Net() diff --git a/tests/st/ops/gpu/test_loss.py b/tests/st/ops/gpu/test_loss.py new file mode 100644 index 00000000000..693f8f01bd9 --- /dev/null +++ b/tests/st/ops/gpu/test_loss.py @@ -0,0 +1,51 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +""" test loss """ +import numpy as np + +from mindspore import Tensor +from mindspore.ops import operations as P +from mindspore.nn.loss.loss import _Loss + +class WeightedLoss(_Loss): + def __init__(self, reduction='mean', weights=1.0): + super(WeightedLoss, self).__init__(reduction, weights) + self.abs = P.Abs() + + def construct(self, base, target): + x = self.abs(base - target) + return self.get_loss(x) + +def test_WeightedLoss(): + loss = WeightedLoss() + input_data = Tensor(np.array([[1, 2, 3], [2, 3, 4]]).astype(np.float32)) + target_data = Tensor(np.array([[0, 2, 5], [3, 1, 1]]).astype(np.float32)) + output_data = loss(input_data, target_data) + + error_range = np.ones(shape=output_data.shape) * 10e-6 + loss.weights = 1.0 + test_output = loss(input_data, target_data) + diff = test_output - output_data * loss.weights + assert np.all(abs(diff.asnumpy()) < error_range) + + loss.weights = 2.0 + test_output = loss(input_data, target_data) + diff = test_output - output_data * loss.weights + assert np.all(abs(diff.asnumpy()) < error_range) + + loss.weights = 3 + test_output = loss(input_data, target_data) + diff = test_output - output_data * loss.weights + assert np.all(abs(diff.asnumpy()) < error_range)