change _Loss to Loss

chenhaozhe 2021-06-01 14:52:26 +08:00
parent 3fd22fde0b
commit 9da8534396
45 changed files with 131 additions and 111 deletions

View File

@@ -15,6 +15,7 @@
"""loss"""
import mindspore
import mindspore.common.dtype as mstype
from mindspore import log
from mindspore.common.tensor import Tensor
from mindspore.common.parameter import Parameter
from mindspore.ops import operations as P
@@ -27,13 +28,18 @@ from mindspore._checkparam import Validator as validator
from mindspore._checkparam import Rel
from ... import context
class _Loss(Cell):
class Loss(Cell):
"""
Base class for other losses.
Other losses derived from this class can use the method `self.get_loss` to apply the configured reduction to their loss values.
Args:
reduction (str): Type of reduction to be applied to loss. The optional values are "mean", "sum", and "none".
Default: "mean".
"""
def __init__(self, reduction='mean'):
super(_Loss, self).__init__()
super(Loss, self).__init__()
if reduction is None:
reduction = 'none'
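
For reference, a minimal sketch of how a loss derived from the renamed base class is meant to use this reduction machinery; the class MyMAELoss and its call site are illustrative, not part of this commit:

import numpy as np
from mindspore import Tensor
from mindspore.nn.loss.loss import Loss
from mindspore.ops import operations as P

class MyMAELoss(Loss):
    """Toy mean-absolute-error loss; reduction is applied by the base class."""
    def __init__(self, reduction='mean'):
        # 'mean' (default), 'sum', or 'none'; None is rewritten to 'none' above
        super(MyMAELoss, self).__init__(reduction)
        self.abs = P.Abs()

    def construct(self, base, target):
        x = self.abs(base - target)
        return self.get_loss(x)  # applies the configured reduction

loss = MyMAELoss(reduction='sum')
out = loss(Tensor(np.array([1.0, 2.0, 3.0], np.float32)),
           Tensor(np.array([1.0, 0.0, 1.0], np.float32)))  # |0| + |2| + |2| = 4.0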
@@ -79,13 +85,27 @@ class _Loss(Cell):
def construct(self, base, target):
raise NotImplementedError
class _Loss(Loss):
"""
Base class for other losses.
"""
def __init__(self, reduction='mean'):
log.warning("'_Loss' is deprecated from version 1.3 and "
"will be removed in a future version, use 'Loss' instead.")
super(_Loss, self).__init__(reduction)
def construct(self, base, target):
raise NotImplementedError
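
The shim above keeps existing imports of `_Loss` working while pointing users at the public name. A hedged sketch of the migration this commit performs across the repository (the class names below are illustrative):

from mindspore.nn.loss.loss import _Loss, Loss

class OldStyleLoss(_Loss):   # still accepted, but constructing it now logs:
    pass                     # "'_Loss' is deprecated from version 1.3 ... use 'Loss' instead."

class NewStyleLoss(Loss):    # the form this commit migrates all model code to
    pass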
@constexpr
def _check_input_type(param_name, input_data, allow_dtype, cls_name):
if input_data is not None and not isinstance(input_data, allow_dtype):
raise TypeError(f"For '{cls_name}', the '{param_name}' should be '{allow_dtype}', "
f"but got '{F.typeof(input_data)}'")
class L1Loss(_Loss):
class L1Loss(Loss):
r"""
L1Loss creates a criterion to measure the mean absolute error (MAE) between :math:`x` and :math:`y` element-wise,
where :math:`x` is the input Tensor and :math:`y` is the target Tensor.
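Concretely, using the textbook MAE definition (stated here for orientation, not copied from the hunk): the unreduced loss is :math:`l_n = \left| x_n - y_n \right|`; with the default ``reduction='mean'`` the result is :math:`\ell(x, y) = \frac{1}{N} \sum_{n=1}^{N} l_n`, with ``'sum'`` it is :math:`\sum_{n=1}^{N} l_n`, and with ``'none'`` the element-wise tensor is returned.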
@@ -135,7 +155,7 @@ class L1Loss(_Loss):
return self.get_loss(x)
class MSELoss(_Loss):
class MSELoss(Loss):
r"""
MSELoss creates a criterion to measure the mean squared error (squared L2-norm) between :math:`x` and :math:`y`
element-wise, where :math:`x` is the input and :math:`y` is the target.
@@ -181,7 +201,7 @@ class MSELoss(_Loss):
return self.get_loss(x)
class RMSELoss(_Loss):
class RMSELoss(Loss):
r"""
RMSELoss creates a standard to measure the root mean square error between :math:`x` and :math:`y`
element-wise, where :math:`x` is the input and :math:`y` is the target.
@@ -222,7 +242,7 @@ class RMSELoss(_Loss):
return rmse_loss
class MAELoss(_Loss):
class MAELoss(Loss):
r"""
MAELoss creates a standard to measure the average absolute error between :math:`x` and :math:`y`
element-wise, where :math:`x` is the input and :math:`y` is the target.
@@ -270,7 +290,7 @@ class MAELoss(_Loss):
return self.get_loss(x)
class SmoothL1Loss(_Loss):
class SmoothL1Loss(Loss):
r"""
A loss class for learning region proposals.
@@ -332,7 +352,7 @@ class SmoothL1Loss(_Loss):
return self.smooth_l1_loss(base, target)
class SoftmaxCrossEntropyWithLogits(_Loss):
class SoftmaxCrossEntropyWithLogits(Loss):
r"""
Computes softmax cross entropy between logits and labels.
@@ -419,7 +439,7 @@ def _check_label_dtype(labels_dtype, cls_name):
validator.check_type_name("labels", labels_dtype, [mstype.int32, mstype.int64], cls_name)
class DiceLoss(_Loss):
class DiceLoss(Loss):
r"""
The Dice coefficient is a set-similarity measure used to calculate the similarity between two samples. The
value of the Dice coefficient is 1 when the segmentation result is the best and 0 when the segmentation result
@@ -493,7 +513,7 @@ def _check_weights(weight_shape, label_shape):
raise ValueError("The weight shape[0] should be equal to label.shape[1].")
class MultiClassDiceLoss(_Loss):
class MultiClassDiceLoss(Loss):
r"""
When there are multiple classes, the label is transformed into multiple binary classifications by one-hot encoding.
Each channel can then be regarded as a binary classification problem, so it can be
@@ -572,7 +592,7 @@ class MultiClassDiceLoss(_Loss):
return total_loss/label.shape[1]
class SampledSoftmaxLoss(_Loss):
class SampledSoftmaxLoss(Loss):
r"""
Computes the sampled softmax training loss.
@@ -795,7 +815,7 @@ class SampledSoftmaxLoss(_Loss):
return out_logits, out_labels
class BCELoss(_Loss):
class BCELoss(Loss):
r"""
BCELoss creates a criterion to measure the binary cross entropy between the true labels and predicted labels.
@@ -876,7 +896,7 @@ def _check_reduced_shape_valid(ori_shape, reduced_shape, axis, cls_name):
validator.check_reduce_shape(ori_shape, reduced_shape, axis, cls_name)
class CosineEmbeddingLoss(_Loss):
class CosineEmbeddingLoss(Loss):
r"""
Computes the similarity between two tensors using cosine distance.
@@ -951,7 +971,7 @@ class CosineEmbeddingLoss(_Loss):
return self.get_loss(output_unreduced)
class BCEWithLogitsLoss(_Loss):
class BCEWithLogitsLoss(Loss):
r"""
Applies the sigmoid activation function to the input `predict`, and uses the given logits to compute the binary
cross entropy between the target and the output.
@@ -1065,7 +1085,7 @@ def _check_input_dtype(targets_dtype, cls_name):
mstype.float32], cls_name)
class FocalLoss(_Loss):
class FocalLoss(Loss):
r"""
The loss function proposed by Kaiming He's team in the paper ``Focal Loss for Dense Object Detection`` improves the
performance of image object detection. It is a loss function designed to address the imbalance of categories and the difference of

View File

@@ -34,7 +34,7 @@ from mindspore.train.callback import Callback, ModelCheckpoint
from mindspore.train import lineage_pb2
from mindspore.train.callback._dataset_graph import DatasetGraph
from mindspore.nn.optim.optimizer import Optimizer
from mindspore.nn.loss.loss import _Loss
from mindspore.nn.loss.loss import Loss
from mindspore.train._utils import check_value_type
HYPER_CONFIG_ENV_NAME = "MINDINSIGHT_HYPER_CONFIG"
@@ -907,7 +907,7 @@ class SummaryCollector(Callback):
network = cb_params.eval_network
for _, cell in network.cells_and_names():
if isinstance(cell, _Loss):
if isinstance(cell, Loss):
loss_fn = cell
break
return loss_fn
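
A practical consequence of the rename (sketch; MyLoss is hypothetical): a user-defined loss cell now only needs to subclass the public `Loss` for the detection loop above to find it, where it previously had to subclass the private `_Loss`:

from mindspore.nn.loss.loss import Loss
from mindspore.ops import operations as P

class MyLoss(Loss):
    def __init__(self, reduction='mean'):
        super(MyLoss, self).__init__(reduction)
        self.abs = P.Abs()

    def construct(self, base, target):
        return self.get_loss(self.abs(base - target))

# SummaryCollector walks network.cells_and_names() during evaluation and
# takes the first cell satisfying isinstance(cell, Loss) as loss_fn, so
# MyLoss is picked up without any extra registration.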

View File

@@ -14,13 +14,13 @@
# ============================================================================
"""CTC Loss."""
import numpy as np
from mindspore.nn.loss.loss import _Loss
from mindspore.nn.loss.loss import Loss
from mindspore import Tensor, Parameter
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
class CTCLoss(_Loss):
class CTCLoss(Loss):
"""
CTCLoss definition

View File

@@ -24,7 +24,7 @@ import mindspore.nn as nn
import mindspore.ops.operations as P
import mindspore.common.dtype as mstype
from mindspore import context, Tensor
from mindspore.nn.loss.loss import _Loss
from mindspore.nn.loss.loss import Loss
from mindspore.ops import composite as C
from mindspore.ops import functional as F
from mindspore.context import ParallelMode
@@ -34,7 +34,7 @@ from mindspore.parallel._auto_parallel_context import auto_parallel_context
from src.seq2seq import Encoder, Decoder
class NLLLoss(_Loss):
class NLLLoss(Loss):
def __init__(self, reduction='mean'):
super(NLLLoss, self).__init__(reduction)
self.one_hot = P.OneHot()

View File

@@ -13,14 +13,14 @@
# limitations under the License.
# ============================================================================
"""define loss function for network."""
from mindspore.nn.loss.loss import _Loss
from mindspore.nn.loss.loss import Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
from mindspore import Tensor
import mindspore.nn as nn
class LabelSmoothingCrossEntropy(_Loss):
class LabelSmoothingCrossEntropy(Loss):
def __init__(self, smooth_factor=0.1, num_classes=1000):
super(LabelSmoothingCrossEntropy, self).__init__()

View File

@@ -16,12 +16,12 @@
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.loss.loss import _Loss
from mindspore.nn.loss.loss import Loss
from mindspore.ops import functional as F
from mindspore.ops import operations as P
class CrossEntropySmooth(_Loss):
class CrossEntropySmooth(Loss):
"""CrossEntropy"""
def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
super(CrossEntropySmooth, self).__init__()

View File

@@ -16,12 +16,12 @@
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.loss.loss import _Loss
from mindspore.nn.loss.loss import Loss
from mindspore.ops import functional as F
from mindspore.ops import operations as P
class CrossEntropySmooth(_Loss):
class CrossEntropySmooth(Loss):
"""CrossEntropy"""
def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
super(CrossEntropySmooth, self).__init__()

View File

@@ -19,12 +19,12 @@ from mindspore import nn
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
from mindspore.nn.loss.loss import _Loss
from mindspore.nn.loss.loss import Loss
from mindspore.train.callback import Callback
from mindspore.train.serialization import load_checkpoint, load_param_into_net
from src.mobilenetV2 import MobileNetV2Backbone, MobileNetV2Head, mobilenet_v2
class CrossEntropyWithLabelSmooth(_Loss):
class CrossEntropyWithLabelSmooth(Loss):
"""
CrossEntropyWithLabelSmooth.

View File

@@ -20,7 +20,7 @@ import numpy as np
from mindspore.train.callback import Callback
from mindspore import Tensor
from mindspore import nn
from mindspore.nn.loss.loss import _Loss
from mindspore.nn.loss.loss import Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
@@ -78,7 +78,7 @@ class Monitor(Callback):
np.mean(self.losses), step_mseconds, self.lr_init[cb_params.cur_step_num - 1]))
class CrossEntropyWithLabelSmooth(_Loss):
class CrossEntropyWithLabelSmooth(Loss):
"""
CrossEntropyWithLabelSmooth.

View File

@@ -24,7 +24,7 @@ from mindspore import Tensor
from mindspore import nn
from mindspore.nn.optim.momentum import Momentum
from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
from mindspore.nn.loss.loss import _Loss
from mindspore.nn.loss.loss import Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
@@ -69,7 +69,7 @@ else:
raise ValueError("Unsupported device_target.")
class CrossEntropyWithLabelSmooth(_Loss):
class CrossEntropyWithLabelSmooth(Loss):
"""
CrossEntropyWithLabelSmooth.

View File

@@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================
"""define evaluation loss function for network."""
from mindspore.nn.loss.loss import _Loss
from mindspore.nn.loss.loss import Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore import Tensor
@@ -21,7 +21,7 @@ from mindspore.common import dtype as mstype
import mindspore.nn as nn
class CrossEntropy_Val(_Loss):
class CrossEntropy_Val(Loss):
"""the redefined loss function with SoftmaxCrossEntropyWithLogits, will be used in inference process"""
def __init__(self, smooth_factor=0, num_classes=1000):
super(CrossEntropy_Val, self).__init__()

View File

@@ -17,7 +17,7 @@ import numpy as np
from mindspore import Tensor
import mindspore.nn as nn
from mindspore.nn.loss.loss import _Loss
from mindspore.nn.loss.loss import Loss
import mindspore.ops.operations as P
import mindspore.ops.functional as F
import mindspore.ops.composite as C
@@ -57,7 +57,7 @@ def _clip_grad(clip_type, clip_value, grad):
return new_grad
class CrossEntropy(_Loss):
class CrossEntropy(Loss):
"""the redefined loss function with SoftmaxCrossEntropyWithLogits"""
def __init__(self, smooth_factor=0, num_classes=1000, factor=0.4):
super(CrossEntropy, self).__init__()

View File

@@ -16,12 +16,12 @@
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.loss.loss import _Loss
from mindspore.nn.loss.loss import Loss
from mindspore.ops import functional as F
from mindspore.ops import operations as P
class CrossEntropySmooth(_Loss):
class CrossEntropySmooth(Loss):
"""CrossEntropy"""
def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
super(CrossEntropySmooth, self).__init__()

View File

@@ -16,12 +16,12 @@
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.loss.loss import _Loss
from mindspore.nn.loss.loss import Loss
from mindspore.ops import functional as F
from mindspore.ops import operations as P
class CrossEntropySmooth(_Loss):
class CrossEntropySmooth(Loss):
"""CrossEntropy"""
def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
super(CrossEntropySmooth, self).__init__()

View File

@@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================
"""define loss function for network"""
from mindspore.nn.loss.loss import _Loss
from mindspore.nn.loss.loss import Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore import Tensor
@@ -21,7 +21,7 @@ from mindspore.common import dtype as mstype
import mindspore.nn as nn
class CrossEntropy(_Loss):
class CrossEntropy(Loss):
"""the redefined loss function with SoftmaxCrossEntropyWithLogits"""
def __init__(self, smooth_factor=0, num_classes=1001):

View File

@@ -15,14 +15,14 @@
"""
define loss function for network.
"""
from mindspore.nn.loss.loss import _Loss
from mindspore.nn.loss.loss import Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore import Tensor
from mindspore.common import dtype as mstype
import mindspore.nn as nn
class CrossEntropy(_Loss):
class CrossEntropy(Loss):
"""
The loss function redefined with SoftmaxCrossEntropyWithLogits.
"""

View File

@@ -15,14 +15,14 @@
"""
define loss function for network.
"""
from mindspore.nn.loss.loss import _Loss
from mindspore.nn.loss.loss import Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore import Tensor
from mindspore.common import dtype as mstype
import mindspore.nn as nn
class CrossEntropy(_Loss):
class CrossEntropy(Loss):
"""
The loss function redefined with SoftmaxCrossEntropyWithLogits.
"""

View File

@@ -16,12 +16,12 @@
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.loss.loss import _Loss
from mindspore.nn.loss.loss import Loss
from mindspore.ops import functional as F
from mindspore.ops import operations as P
class CrossEntropySmooth(_Loss):
class CrossEntropySmooth(Loss):
"""CrossEntropy"""
def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
super(CrossEntropySmooth, self).__init__()

View File

@@ -15,11 +15,11 @@
import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.nn.loss.loss import _Loss
from mindspore.nn.loss.loss import Loss
from mindspore.common import dtype as mstype
class JointsMSELoss(_Loss):
class JointsMSELoss(Loss):
def __init__(self, use_target_weight):
super(JointsMSELoss, self).__init__()
self.criterion = nn.MSELoss(reduction='mean')

View File

@@ -16,12 +16,12 @@
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.loss.loss import _Loss
from mindspore.nn.loss.loss import Loss
from mindspore.ops import functional as F
from mindspore.ops import operations as P
class CrossEntropySmooth(_Loss):
class CrossEntropySmooth(Loss):
"""CrossEntropy"""
def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
super(CrossEntropySmooth, self).__init__()

View File

@@ -16,10 +16,10 @@
import mindspore.nn as nn
from mindspore import dtype as mstype
from mindspore.ops import operations as P
from mindspore.nn.loss.loss import _Loss
from mindspore.nn.loss.loss import Loss
from src.model_utils.config import config
class SoftmaxCrossEntropyWithLogits(_Loss):
class SoftmaxCrossEntropyWithLogits(Loss):
def __init__(self):
super(SoftmaxCrossEntropyWithLogits, self).__init__()
self.transpose = P.Transpose()

View File

@@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================
"""define loss function for network"""
from mindspore.nn.loss.loss import _Loss
from mindspore.nn.loss.loss import Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore import Tensor
@@ -21,7 +21,7 @@ from mindspore.common import dtype as mstype
import mindspore.nn as nn
class CrossEntropy(_Loss):
class CrossEntropy(Loss):
"""the redefined loss function with SoftmaxCrossEntropyWithLogits"""
def __init__(self, smooth_factor=0., num_classes=1001):

View File

@@ -14,13 +14,13 @@
# ============================================================================
"""CTC Loss."""
import numpy as np
from mindspore.nn.loss.loss import _Loss
from mindspore.nn.loss.loss import Loss
from mindspore import Tensor, Parameter
from mindspore.common import dtype as mstype
from mindspore.ops import operations as P
class CTCLoss(_Loss):
class CTCLoss(Loss):
"""
CTCLoss definition

View File

@@ -14,10 +14,10 @@
# ============================================================================
"""NLLLoss cell"""
import mindspore.ops.operations as P
from mindspore.nn.loss.loss import _Loss
from mindspore.nn.loss.loss import Loss
from mindspore.ops import functional as F
class NLLLoss(_Loss):
class NLLLoss(Loss):
'''
NLLLoss function
'''

View File

@@ -17,7 +17,7 @@ import mindspore.nn as nn
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
from mindspore.nn.loss.loss import _Loss
from mindspore.nn.loss.loss import Loss
from mindspore import Tensor
eps = 1e-24
@@ -41,7 +41,7 @@ class log_softmax(nn.Cell):
return result
class CEWithIgnoreIndex3D(_Loss):
class CEWithIgnoreIndex3D(Loss):
'''CEWithIgnoreIndex3D'''
def __init__(self):
super(CEWithIgnoreIndex3D, self).__init__()

View File

@@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================
"""Face Recognition loss."""
from mindspore.nn.loss.loss import _Loss
from mindspore.nn.loss.loss import Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore import Tensor
@@ -23,7 +23,7 @@ import mindspore.nn as nn
eps = 1e-24
class CrossEntropyNew(_Loss):
class CrossEntropyNew(Loss):
'''CrossEntropyNew'''
def __init__(self, smooth_factor=0., num_classes=1000):
super(CrossEntropyNew, self).__init__()
@@ -42,7 +42,7 @@ class CrossEntropyNew(_Loss):
return loss
class CrossEntropy(_Loss):
class CrossEntropy(Loss):
'''CrossEntropy'''
def __init__(self):
super(CrossEntropy, self).__init__()
@@ -106,7 +106,7 @@ class CrossEntropyWithIgnoreIndex(nn.Cell):
eps = 1e-24
class CEWithIgnoreIndex3D(_Loss):
class CEWithIgnoreIndex3D(Loss):
'''CEWithIgnoreIndex3D'''
def __init__(self):
super(CEWithIgnoreIndex3D, self).__init__()

View File

@@ -16,12 +16,12 @@
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.loss.loss import _Loss
from mindspore.nn.loss.loss import Loss
from mindspore.ops import functional as F
from mindspore.ops import operations as P
class CrossEntropySmooth(_Loss):
class CrossEntropySmooth(Loss):
"""CrossEntropy"""
def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
super(CrossEntropySmooth, self).__init__()

View File

@@ -14,7 +14,7 @@
# ============================================================================
"""define loss function for network."""
from mindspore.nn.loss.loss import _Loss
from mindspore.nn.loss.loss import Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore import Tensor
@@ -22,7 +22,7 @@ from mindspore.common import dtype as mstype
import mindspore.nn as nn
class LabelSmoothingCrossEntropy(_Loss):
class LabelSmoothingCrossEntropy(Loss):
"""cross-entropy with label smoothing"""
def __init__(self, smooth_factor=0.1, num_classes=1000):

View File

@@ -16,7 +16,7 @@
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.loss.loss import _Loss
from mindspore.nn.loss.loss import Loss
from mindspore.ops import functional as F
from mindspore.ops import operations as P
import mindspore.ops as ops
@@ -57,7 +57,7 @@ class SoftmaxCrossEntropyExpand(nn.Cell): # pylint: disable=missing-docstring
return loss
class CrossEntropySmooth(_Loss):
class CrossEntropySmooth(Loss):
"""CrossEntropy"""
def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):

View File

@@ -16,12 +16,12 @@
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.loss.loss import _Loss
from mindspore.nn.loss.loss import Loss
from mindspore.ops import functional as F
from mindspore.ops import operations as P
class CrossEntropySmooth(_Loss):
class CrossEntropySmooth(Loss):
"""CrossEntropy"""
def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
super(CrossEntropySmooth, self).__init__()

View File

@@ -14,7 +14,7 @@
# ============================================================================
"""define loss function for network."""
from mindspore.nn.loss.loss import _Loss
from mindspore.nn.loss.loss import Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore import Tensor
@@ -22,7 +22,7 @@ from mindspore.common import dtype as mstype
import mindspore.nn as nn
class LabelSmoothingCrossEntropy(_Loss):
class LabelSmoothingCrossEntropy(Loss):
"""cross-entropy with label smoothing"""
def __init__(self, smooth_factor=0.1, num_classes=1000):

View File

@@ -16,12 +16,12 @@
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.loss.loss import _Loss
from mindspore.nn.loss.loss import Loss
from mindspore.ops import functional as F
from mindspore.ops import operations as P
class CrossEntropySmooth(_Loss):
class CrossEntropySmooth(Loss):
"""CrossEntropy"""
def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
super(CrossEntropySmooth, self).__init__()

View File

@@ -14,7 +14,7 @@
# ============================================================================
"""define loss function for network."""
from mindspore.nn.loss.loss import _Loss
from mindspore.nn.loss.loss import Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore import Tensor
@@ -22,7 +22,7 @@ from mindspore.common import dtype as mstype
import mindspore.nn as nn
class LabelSmoothingCrossEntropy(_Loss):
class LabelSmoothingCrossEntropy(Loss):
"""cross-entropy with label smoothing"""
def __init__(self, smooth_factor=0.1, num_classes=1000):

View File

@@ -13,7 +13,7 @@
# limitations under the License.
# ============================================================================
"""define loss function for network"""
from mindspore.nn.loss.loss import _Loss
from mindspore.nn.loss.loss import Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore import Tensor
@@ -21,7 +21,7 @@ from mindspore.common import dtype as mstype
import mindspore.nn as nn
class CrossEntropy(_Loss):
class CrossEntropy(Loss):
"""the redefined loss function with SoftmaxCrossEntropyWithLogits"""
def __init__(self, smooth_factor=0., num_classes=1001):

View File

@@ -14,14 +14,14 @@
# ===========================================================================
"""DSCNN loss."""
import mindspore.nn as nn
from mindspore.nn.loss.loss import _Loss
from mindspore.nn.loss.loss import Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore import Tensor
from mindspore.common import dtype as mstype
class CrossEntropy(_Loss):
class CrossEntropy(Loss):
'''Build CrossEntropy Loss.'''
def __init__(self, smooth_factor=0., num_classes=1000):
super(CrossEntropy, self).__init__()

View File

@@ -23,7 +23,7 @@ import mindspore.ops.functional as F
from mindspore import Tensor
from mindspore.common.initializer import TruncatedNormal
from mindspore.communication.management import init
from mindspore.nn.loss.loss import _Loss
from mindspore.nn.loss.loss import Loss
from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import operations as P
from mindspore.parallel import set_algo_parameters
@@ -245,7 +245,7 @@ def resnet50(class_num=10):
class_num)
class SoftmaxCrossEntropyExpand(_Loss):
class SoftmaxCrossEntropyExpand(Loss):
def __init__(self, sparse=False):
super(SoftmaxCrossEntropyExpand, self).__init__()
self.exp = P.Exp()
@@ -307,15 +307,15 @@ class DataGenerator():
data = (self.generate_data(shape)).astype(np.float32)
stra = [1] * len(shape)
stra[0] = device_num
datas = self.get_parallel_blocks(data, stra)
return Tensor(data), Tensor(datas[rank_id])
data_parallel = self.get_parallel_blocks(data, stra)
return Tensor(data), Tensor(data_parallel[rank_id])
def label_data(self, shape):
data = (self.generate_data(shape) * 1000 / np.prod(shape)).astype(np.int32)
stra = [1] * len(shape)
stra[0] = device_num
datas = self.get_parallel_blocks(data, stra)
return Tensor(data), Tensor(datas[rank_id])
data_parallel = self.get_parallel_blocks(data, stra)
return Tensor(data), Tensor(data_parallel[rank_id])
class Dataset():

View File

@@ -16,12 +16,12 @@
import mindspore.nn as nn
from mindspore import Tensor
from mindspore.common import dtype as mstype
from mindspore.nn.loss.loss import _Loss
from mindspore.nn.loss.loss import Loss
from mindspore.ops import functional as F
from mindspore.ops import operations as P
class CrossEntropySmooth(_Loss):
class CrossEntropySmooth(Loss):
"""CrossEntropy"""
def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
super(CrossEntropySmooth, self).__init__()

View File

@@ -18,11 +18,11 @@ import pytest
from mindspore import Tensor
from mindspore.ops import operations as P
from mindspore.nn.loss.loss import _Loss
from mindspore.nn.loss.loss import Loss
from mindspore.nn.loss.loss import L1Loss
import mindspore.context as context
class WeightedLoss(_Loss):
class WeightedLoss(Loss):
def __init__(self, reduction='mean', weights=1.0):
super(WeightedLoss, self).__init__(reduction)
self.abs = P.Abs()
@@ -72,7 +72,7 @@ def test_weighted_loss_float32():
def test_weighted_loss_float64():
weighted_loss(np.float64)
class CustomLoss(_Loss):
class CustomLoss(Loss):
def __init__(self, reduction='mean'):
super(CustomLoss, self).__init__(reduction)
self.abs = P.Abs()

View File

@@ -142,7 +142,7 @@ def test_sampledsoftmaxloss_reduction_invalid():
with pytest.raises(ValueError):
nn.SampledSoftmaxLoss(num_sampled=4, num_classes=7, reduction="invalid")
# reduction can be None, as defined in _Loss
# reduction can be None, as defined in Loss
# with pytest.raises(ValueError):
# nn.SampledSoftmaxLoss(num_sampled=4, num_classes=7, reduction=None) #
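
The commented-out case depends on the base-class behavior from the first file in this commit, where `__init__` rewrites `reduction=None` to `'none'`. A hedged illustration, assuming `nn.L1Loss` forwards `reduction` unchanged to the `Loss` base class:

import numpy as np
import mindspore.nn as nn
from mindspore import Tensor

loss = nn.L1Loss(reduction=None)  # accepted: Loss.__init__ maps None to 'none'
out = loss(Tensor(np.array([1.0, 2.0], np.float32)),
           Tensor(np.array([2.0, 2.0], np.float32)))
# out is the unreduced element-wise tensor [1.0, 0.0]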

View File

@@ -20,7 +20,7 @@ import numpy as np
from mindspore.train.callback import Callback
from mindspore import Tensor
from mindspore import nn
from mindspore.nn.loss.loss import _Loss
from mindspore.nn.loss.loss import Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
@@ -87,7 +87,7 @@ class Monitor(Callback):
run_context.request_stop()
class CrossEntropyWithLabelSmooth(_Loss):
class CrossEntropyWithLabelSmooth(Loss):
"""
CrossEntropyWithLabelSmooth.

View File

@@ -20,7 +20,7 @@ import numpy as np
from mindspore.train.callback import Callback
from mindspore import Tensor
from mindspore import nn
from mindspore.nn.loss.loss import _Loss
from mindspore.nn.loss.loss import Loss
from mindspore.ops import operations as P
from mindspore.ops import functional as F
from mindspore.common import dtype as mstype
@@ -85,7 +85,7 @@ class Monitor(Callback):
run_context.request_stop()
class CrossEntropy(_Loss):
class CrossEntropy(Loss):
"""the redefined loss function with SoftmaxCrossEntropyWithLogits"""
def __init__(self, smooth_factor=0, num_classes=1001):

View File

@@ -23,7 +23,7 @@ from mindspore import context
from mindspore.common.api import _executor
from mindspore.common.initializer import TruncatedNormal
from mindspore.communication.management import init
from mindspore.nn.loss.loss import _Loss
from mindspore.nn.loss.loss import Loss
from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import operations as P
from mindspore.parallel import _cost_model_context as cost_model_context
@@ -215,7 +215,7 @@ def resnet50(class_num=10):
class_num)
class SoftmaxCrossEntropyExpand(_Loss):
class SoftmaxCrossEntropyExpand(Loss):
def __init__(self, sparse=False):
super(SoftmaxCrossEntropyExpand, self).__init__()
self.exp = P.Exp()

View File

@@ -20,7 +20,7 @@ from mindspore import Tensor, Parameter
from mindspore import context
from mindspore.common import dtype as mstype
from mindspore.common.api import _executor
from mindspore.nn.loss.loss import _Loss
from mindspore.nn.loss.loss import Loss
from mindspore.ops import composite as C
from mindspore.ops import operations as P
from tests.ut.python.ops.test_math_ops import VirtualLoss
@@ -59,7 +59,7 @@ class CustomMatMul(nn.Cell):
return out
class MarginCE(_Loss):
class MarginCE(Loss):
def __init__(self):
super(MarginCE, self).__init__()
self.fc = CustomMatMul(transpose_b=True)

View File

@@ -22,7 +22,7 @@ from mindspore.common import dtype as mstype
from mindspore.common.parameter import ParameterTuple
from mindspore.communication.management import init
from mindspore.nn import Dense, Cell
from mindspore.nn.loss.loss import _Loss
from mindspore.nn.loss.loss import Loss
from mindspore.nn.optim import Momentum
from mindspore.ops import composite as C
from mindspore.ops import functional as F
@@ -64,7 +64,7 @@ class Dataset():
return self
class GatherV2(_Loss):
class GatherV2(Loss):
def __init__(self, index_dim, strategy, index_size=16):
super(GatherV2, self).__init__()
self.pow = P.Pow()
@@ -195,7 +195,7 @@ def test_strategy3():
net_trains(criterion, rank)
class GatherV2Axis1(_Loss):
class GatherV2Axis1(Loss):
def __init__(self, index_dim, strategy, index_size=16):
super(GatherV2Axis1, self).__init__()
self.pow = P.Pow()

View File

@@ -22,7 +22,7 @@ from mindspore import context
import mindspore.common.dtype as mstype
from mindspore.common.api import _executor
from mindspore.common.parameter import Parameter
from mindspore.nn.loss.loss import _Loss
from mindspore.nn.loss.loss import Loss
from mindspore.nn.optim.momentum import Momentum
from mindspore.ops import operations as P
from mindspore.ops import functional as F
@@ -68,7 +68,7 @@ class AllToAllNet(nn.Cell):
return x
class SoftmaxCrossEntropyWithLogits(_Loss):
class SoftmaxCrossEntropyWithLogits(Loss):
def __init__(self,
sparse=False,
reduction='none'):