!19246 Change Loss to LossBase

Merge pull request !19246 from chenhaozhe/change-loss-base
i-robot 2021-07-05 01:24:23 +00:00 committed by Gitee
commit b93907fdca
46 changed files with 111 additions and 111 deletions

View File

@@ -19,13 +19,13 @@ Cells of loss function. Loss function in machine learning is the target of the m
 It shows how well the model works on a dataset and the optimization target which the optimizer is searching.
 """
-from .loss import Loss, L1Loss, MSELoss, SmoothL1Loss, FocalLoss,\
+from .loss import LossBase, L1Loss, MSELoss, SmoothL1Loss, FocalLoss,\
     SoftmaxCrossEntropyWithLogits, BCELoss, CosineEmbeddingLoss, \
     SampledSoftmaxLoss, DiceLoss, BCEWithLogitsLoss, MultiClassDiceLoss,\
     RMSELoss, MAELoss
-__all__ = ['Loss', 'L1Loss', 'MSELoss', 'SmoothL1Loss', 'FocalLoss',
+__all__ = ['LossBase', 'L1Loss', 'MSELoss', 'SmoothL1Loss', 'FocalLoss',
            'SoftmaxCrossEntropyWithLogits', 'BCELoss', 'BCEWithLogitsLoss',
            'CosineEmbeddingLoss', 'SampledSoftmaxLoss', 'DiceLoss', 'MultiClassDiceLoss',
            'RMSELoss', 'MAELoss']
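
Note: this hunk only renames the exported base class; the concrete losses keep their names. A minimal sanity check of the renamed public API (illustrative sketch, assuming MindSpore 1.3+, where mindspore.nn re-exports the loss package):

    from mindspore.nn import LossBase, L1Loss

    # All concrete losses still derive from the common base, now named LossBase.
    assert issubclass(L1Loss, LossBase)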

View File

@@ -28,7 +28,7 @@ from mindspore._checkparam import Validator as validator
 from mindspore._checkparam import Rel
 from ... import context
-class Loss(Cell):
+class LossBase(Cell):
     """
     Base class for other losses.
@@ -47,7 +47,7 @@ class Loss(Cell):
     """
     def __init__(self, reduction='mean'):
         """Initialize Loss."""
-        super(Loss, self).__init__()
+        super(LossBase, self).__init__()
         if reduction not in ('mean', 'sum', 'none'):
             raise ValueError(f"reduction method for {reduction} is not supported")
@@ -92,15 +92,15 @@ class Loss(Cell):
         raise NotImplementedError
-class _Loss(Loss):
+class _Loss(LossBase):
     """
     Base class for other losses.
     """
     def __init__(self, reduction='mean'):
         """Initialize _Loss."""
         log.warning("'_Loss' is deprecated from version 1.3 and "
-                    "will be removed in a future version, use 'Loss' instead.")
-        super(_Loss, self).__init__()
+                    "will be removed in a future version, use 'LossBase' instead.")
+        super(_Loss, self).__init__(reduction)
     def construct(self, base, target):
         raise NotImplementedError
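
Note: besides the rename, this hunk also changes behavior: _Loss becomes a deprecated shim over LossBase and now forwards reduction to the base constructor instead of dropping it. A rough sketch of what legacy code sees after this commit (illustrative only):

    from mindspore.nn.loss.loss import _Loss, LossBase

    class OldStyleLoss(_Loss):
        def construct(self, base, target):
            return self.get_loss(base - target)

    loss = OldStyleLoss('sum')         # logs: "'_Loss' is deprecated ... use 'LossBase' instead"
    assert isinstance(loss, LossBase)  # _Loss now subclasses LossBase
    # 'sum' is honored now that reduction is forwarded to LossBase.__init__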
@@ -112,7 +112,7 @@ def _check_is_tensor(param_name, input_data, cls_name):
         raise TypeError(f"For '{cls_name}', the '{param_name}' should be '{mstype.tensor_type}', "
                         f"but got '{F.typeof(input_data)}'")
-class L1Loss(Loss):
+class L1Loss(LossBase):
     r"""
     L1Loss creates a criterion to measure the mean absolute error (MAE) between :math:`x` and :math:`y` element-wise,
     where :math:`x` is the input Tensor and :math:`y` is the target Tensor.
@@ -182,7 +182,7 @@ class L1Loss(Loss):
         return self.get_loss(x)
-class MSELoss(Loss):
+class MSELoss(LossBase):
     r"""
     MSELoss creates a criterion to measure the mean squared error (squared L2-norm) between :math:`x` and :math:`y`
     element-wise, where :math:`x` is the input and :math:`y` is the target.
@@ -247,7 +247,7 @@ class MSELoss(Loss):
         return self.get_loss(x)
-class RMSELoss(Loss):
+class RMSELoss(LossBase):
     r"""
     RMSELoss creates a criterion to measure the root mean square error between :math:`x` and :math:`y`
     element-wise, where :math:`x` is the input and :math:`y` is the target.
@@ -299,7 +299,7 @@ class RMSELoss(Loss):
         return rmse_loss
-class MAELoss(Loss):
+class MAELoss(LossBase):
     r"""
     MAELoss creates a criterion to measure the average absolute error between :math:`x` and :math:`y`
     element-wise, where :math:`x` is the input and :math:`y` is the target.
@@ -369,7 +369,7 @@ class MAELoss(Loss):
         return self.get_loss(x)
-class SmoothL1Loss(Loss):
+class SmoothL1Loss(LossBase):
     r"""
     A loss class for learning region proposals.
@@ -434,7 +434,7 @@ class SmoothL1Loss(Loss):
         return self.smooth_l1_loss(base, target)
-class SoftmaxCrossEntropyWithLogits(Loss):
+class SoftmaxCrossEntropyWithLogits(LossBase):
     r"""
     Computes softmax cross entropy between logits and labels.
@@ -527,7 +527,7 @@ def _check_label_dtype(labels_dtype, cls_name):
     validator.check_type_name("labels", labels_dtype, [mstype.int32, mstype.int64], cls_name)
-class DiceLoss(Loss):
+class DiceLoss(LossBase):
     r"""
     The Dice coefficient is a set similarity loss. It is used to calculate the similarity between two samples. The
     value of the Dice coefficient is 1 when the segmentation result is the best and 0 when the segmentation result
@@ -604,7 +604,7 @@ def _check_weights(weight_shape, label_shape):
         raise ValueError("The weight shape[0] should be equal to label.shape[1].")
-class MultiClassDiceLoss(Loss):
+class MultiClassDiceLoss(LossBase):
     r"""
     When there are multiple classifications, label is transformed into multiple binary classifications by one hot.
     For each channel section in the channel, it can be regarded as a binary classification problem, so it can be
@@ -684,7 +684,7 @@ class MultiClassDiceLoss(Loss):
         return total_loss/label.shape[1]
-class SampledSoftmaxLoss(Loss):
+class SampledSoftmaxLoss(LossBase):
     r"""
     Computes the sampled softmax training loss. This operator can accelerate the trainging of the softmax classifier
     over a large number of classes.
@@ -904,7 +904,7 @@ class SampledSoftmaxLoss(Loss):
         return out_logits, out_labels
-class BCELoss(Loss):
+class BCELoss(LossBase):
     r"""
     BCELoss creates a criterion to measure the binary cross entropy between the true labels and predicted labels.
@@ -987,7 +987,7 @@ def _check_reduced_shape_valid(ori_shape, reduced_shape, axis, cls_name):
     validator.check_reduce_shape(ori_shape, reduced_shape, axis, cls_name)
-class CosineEmbeddingLoss(Loss):
+class CosineEmbeddingLoss(LossBase):
     r"""
     CosineEmbeddingLoss creates a criterion to measure the similarity between two tensors using cosine distance.
@@ -1064,7 +1064,7 @@ class CosineEmbeddingLoss(Loss):
         return self.get_loss(output_unreduced)
-class BCEWithLogitsLoss(Loss):
+class BCEWithLogitsLoss(LossBase):
     r"""
     Adds sigmoid activation function to input logits, and uses the given logits to compute binary cross entropy
     between the labels and the output.
@@ -1181,7 +1181,7 @@ def _check_input_dtype(targets_dtype, cls_name):
                                            mstype.float32], cls_name)
-class FocalLoss(Loss):
+class FocalLoss(LossBase):
     r"""
     The loss function proposed by Kaiming team in their paper ``Focal Loss for Dense Object Detection`` improves the
     effect of image object detection. It is a loss function to solve the imbalance of categories and the difference of
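
For downstream models, migrating a custom loss is a one-line change of the base class; everything else (get_loss, the reduction argument, construct) is unchanged. A minimal sketch of the post-rename pattern (names here are illustrative, not part of this commit):

    import mindspore.nn as nn
    from mindspore.ops import operations as P

    class MyMAELoss(nn.LossBase):  # previously: nn.Loss
        def __init__(self, reduction='mean'):
            super(MyMAELoss, self).__init__(reduction)
            self.abs = P.Abs()

        def construct(self, base, target):
            x = self.abs(base - target)
            return self.get_loss(x)  # applies the configured 'mean'/'sum'/'none' reduction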

View File

@@ -34,7 +34,7 @@ from mindspore.train.callback import Callback, ModelCheckpoint
 from mindspore.train import lineage_pb2
 from mindspore.train.callback._dataset_graph import DatasetGraph
 from mindspore.nn.optim.optimizer import Optimizer
-from mindspore.nn.loss.loss import Loss
+from mindspore.nn.loss.loss import LossBase
 from mindspore.train._utils import check_value_type
 HYPER_CONFIG_ENV_NAME = "MINDINSIGHT_HYPER_CONFIG"
@@ -909,7 +909,7 @@ class SummaryCollector(Callback):
         network = cb_params.eval_network
         for _, cell in network.cells_and_names():
-            if isinstance(cell, Loss):
+            if isinstance(cell, LossBase):
                 loss_fn = cell
                 break
         return loss_fn
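
SummaryCollector finds the loss cell by scanning the evaluation network for the first LossBase instance, so user-defined losses are only auto-detected if they subclass the renamed base. Schematically (a simplified sketch of the method above, not the full implementation):

    def find_loss_fn(network):
        """Return the first LossBase cell in the network, or None."""
        for _, cell in network.cells_and_names():
            if isinstance(cell, LossBase):
                return cell
        return None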

View File

@@ -14,13 +14,13 @@
 # ============================================================================
 """CTC Loss."""
 import numpy as np
-from mindspore.nn.loss.loss import _Loss
+from mindspore.nn.loss.loss import LossBase
 from mindspore import Tensor, Parameter
 from mindspore.common import dtype as mstype
 from mindspore.ops import operations as P
-class CTCLoss(_Loss):
+class CTCLoss(LossBase):
     """
     CTCLoss definition

View File

@@ -24,7 +24,7 @@ import mindspore.nn as nn
 import mindspore.ops.operations as P
 import mindspore.common.dtype as mstype
 from mindspore import context, Tensor
-from mindspore.nn.loss.loss import Loss
+from mindspore.nn.loss.loss import LossBase
 from mindspore.ops import composite as C
 from mindspore.ops import functional as F
 from mindspore.context import ParallelMode
@@ -34,7 +34,7 @@ from mindspore.parallel._auto_parallel_context import auto_parallel_context
 from src.seq2seq import Encoder, Decoder
-class NLLLoss(Loss):
+class NLLLoss(LossBase):
     def __init__(self, reduction='mean'):
         super(NLLLoss, self).__init__(reduction)
         self.one_hot = P.OneHot()

View File

@@ -13,7 +13,7 @@
 # limitations under the License.
 # ============================================================================
 """define loss function for network."""
-from mindspore.nn.loss.loss import Loss
+from mindspore.nn.loss.loss import LossBase
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore import Tensor
@@ -21,7 +21,7 @@ from mindspore.common import dtype as mstype
 import mindspore.nn as nn
-class CrossEntropy(Loss):
+class CrossEntropy(LossBase):
     """the redefined loss function with SoftmaxCrossEntropyWithLogits"""
     def __init__(self, smooth_factor=0., num_classes=1000):
         super(CrossEntropy, self).__init__()

View File

@@ -13,14 +13,14 @@
 # limitations under the License.
 # ============================================================================
 """define loss function for network."""
-from mindspore.nn.loss.loss import Loss
+from mindspore.nn.loss.loss import LossBase
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore.common import dtype as mstype
 from mindspore import Tensor
 import mindspore.nn as nn
-class LabelSmoothingCrossEntropy(Loss):
+class LabelSmoothingCrossEntropy(LossBase):
     def __init__(self, smooth_factor=0.1, num_classes=1000):
         super(LabelSmoothingCrossEntropy, self).__init__()

View File

@@ -16,12 +16,12 @@
 import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore.common import dtype as mstype
-from mindspore.nn.loss.loss import Loss
+from mindspore.nn.loss.loss import LossBase
 from mindspore.ops import functional as F
 from mindspore.ops import operations as P
-class CrossEntropySmooth(Loss):
+class CrossEntropySmooth(LossBase):
     """CrossEntropy"""
     def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
         super(CrossEntropySmooth, self).__init__()

View File

@@ -16,12 +16,12 @@
 import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore.common import dtype as mstype
-from mindspore.nn.loss.loss import Loss
+from mindspore.nn.loss.loss import LossBase
 from mindspore.ops import functional as F
 from mindspore.ops import operations as P
-class CrossEntropySmooth(Loss):
+class CrossEntropySmooth(LossBase):
     """CrossEntropy"""
     def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
         super(CrossEntropySmooth, self).__init__()

View File

@@ -19,12 +19,12 @@ from mindspore import nn
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore.common import dtype as mstype
-from mindspore.nn.loss.loss import Loss
+from mindspore.nn.loss.loss import LossBase
 from mindspore.train.callback import Callback
 from mindspore.train.serialization import load_checkpoint, load_param_into_net
 from src.mobilenetV2 import MobileNetV2Backbone, MobileNetV2Head, mobilenet_v2
-class CrossEntropyWithLabelSmooth(Loss):
+class CrossEntropyWithLabelSmooth(LossBase):
     """
     CrossEntropyWith LabelSmooth.

View File

@@ -20,7 +20,7 @@ import numpy as np
 from mindspore.train.callback import Callback
 from mindspore import Tensor
 from mindspore import nn
-from mindspore.nn.loss.loss import Loss
+from mindspore.nn.loss.loss import LossBase
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore.common import dtype as mstype
@@ -78,7 +78,7 @@ class Monitor(Callback):
             np.mean(self.losses), step_mseconds, self.lr_init[cb_params.cur_step_num - 1]))
-class CrossEntropyWithLabelSmooth(Loss):
+class CrossEntropyWithLabelSmooth(LossBase):
     """
     CrossEntropyWith LabelSmooth.

View File

@@ -24,7 +24,7 @@ from mindspore import Tensor
 from mindspore import nn
 from mindspore.nn.optim.momentum import Momentum
 from mindspore.nn.loss import SoftmaxCrossEntropyWithLogits
-from mindspore.nn.loss.loss import Loss
+from mindspore.nn.loss.loss import LossBase
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore.common import dtype as mstype
@@ -69,7 +69,7 @@ else:
     raise ValueError("Unsupported device_target.")
-class CrossEntropyWithLabelSmooth(Loss):
+class CrossEntropyWithLabelSmooth(LossBase):
     """
     CrossEntropyWith LabelSmooth.

View File

@@ -13,7 +13,7 @@
 # limitations under the License.
 # ============================================================================
 """define evaluation loss function for network."""
-from mindspore.nn.loss.loss import Loss
+from mindspore.nn.loss.loss import LossBase
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore import Tensor
@@ -21,7 +21,7 @@ from mindspore.common import dtype as mstype
 import mindspore.nn as nn
-class CrossEntropy_Val(Loss):
+class CrossEntropy_Val(LossBase):
     """the redefined loss function with SoftmaxCrossEntropyWithLogits, will be used in inference process"""
     def __init__(self, smooth_factor=0, num_classes=1000):
         super(CrossEntropy_Val, self).__init__()

View File

@@ -18,7 +18,7 @@ import numpy as np
 from mindspore import context
 from mindspore import Tensor
 import mindspore.nn as nn
-from mindspore.nn.loss.loss import Loss
+from mindspore.nn.loss.loss import LossBase
 import mindspore.ops.operations as P
 import mindspore.ops.functional as F
 import mindspore.ops.composite as C
@@ -57,7 +57,7 @@ def _clip_grad(clip_type, clip_value, grad):
     return new_grad
-class CrossEntropy(Loss):
+class CrossEntropy(LossBase):
     """the redefined loss function with SoftmaxCrossEntropyWithLogits"""
     def __init__(self, smooth_factor=0, num_classes=1000, factor=0.4):
         super(CrossEntropy, self).__init__()

View File

@@ -16,12 +16,12 @@
 import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore.common import dtype as mstype
-from mindspore.nn.loss.loss import Loss
+from mindspore.nn.loss.loss import LossBase
 from mindspore.ops import functional as F
 from mindspore.ops import operations as P
-class CrossEntropySmooth(Loss):
+class CrossEntropySmooth(LossBase):
     """CrossEntropy"""
     def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
         super(CrossEntropySmooth, self).__init__()

View File

@@ -16,12 +16,12 @@
 import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore.common import dtype as mstype
-from mindspore.nn.loss.loss import Loss
+from mindspore.nn.loss.loss import LossBase
 from mindspore.ops import functional as F
 from mindspore.ops import operations as P
-class CrossEntropySmooth(Loss):
+class CrossEntropySmooth(LossBase):
     """CrossEntropy"""
     def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
         super(CrossEntropySmooth, self).__init__()

View File

@@ -13,7 +13,7 @@
 # limitations under the License.
 # ============================================================================
 """define loss function for network"""
-from mindspore.nn.loss.loss import Loss
+from mindspore.nn.loss.loss import LossBase
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore import Tensor
@@ -21,7 +21,7 @@ from mindspore.common import dtype as mstype
 import mindspore.nn as nn
-class CrossEntropy(Loss):
+class CrossEntropy(LossBase):
     """the redefined loss function with SoftmaxCrossEntropyWithLogits"""
     def __init__(self, smooth_factor=0, num_classes=1001):

View File

@@ -72,7 +72,7 @@ class Loss(nn.Cell):
         raise NotImplementedError
-class CrossEntropy(Loss):
+class CrossEntropy(LossBase):
     """CrossEntropy"""
     def __init__(self, smooth_factor=0., num_classes=1000):
         super(CrossEntropy, self).__init__()

View File

@@ -15,14 +15,14 @@
 """
 define loss function for network.
 """
-from mindspore.nn.loss.loss import Loss
+from mindspore.nn.loss.loss import LossBase
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore import Tensor
 from mindspore.common import dtype as mstype
 import mindspore.nn as nn
-class CrossEntropy(Loss):
+class CrossEntropy(LossBase):
     """
     the redefined loss function with SoftmaxCrossEntropyWithLogits.
     """

View File

@@ -16,12 +16,12 @@
 import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore.common import dtype as mstype
-from mindspore.nn.loss.loss import Loss
+from mindspore.nn.loss.loss import LossBase
 from mindspore.ops import functional as F
 from mindspore.ops import operations as P
-class CrossEntropySmooth(Loss):
+class CrossEntropySmooth(LossBase):
     """CrossEntropy"""
     def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
         super(CrossEntropySmooth, self).__init__()

View File

@@ -15,11 +15,11 @@
 import mindspore.nn as nn
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
-from mindspore.nn.loss.loss import Loss
+from mindspore.nn.loss.loss import LossBase
 from mindspore.common import dtype as mstype
-class JointsMSELoss(Loss):
+class JointsMSELoss(LossBase):
     def __init__(self, use_target_weight):
         super(JointsMSELoss, self).__init__()
         self.criterion = nn.MSELoss(reduction='mean')

View File

@@ -16,12 +16,12 @@
 import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore.common import dtype as mstype
-from mindspore.nn.loss.loss import Loss
+from mindspore.nn.loss.loss import LossBase
 from mindspore.ops import functional as F
 from mindspore.ops import operations as P
-class CrossEntropySmooth(Loss):
+class CrossEntropySmooth(LossBase):
     """CrossEntropy"""
     def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
         super(CrossEntropySmooth, self).__init__()

View File

@@ -16,10 +16,10 @@
 import mindspore.nn as nn
 from mindspore import dtype as mstype
 from mindspore.ops import operations as P
-from mindspore.nn.loss.loss import Loss
+from mindspore.nn.loss.loss import LossBase
 from src.model_utils.config import config
-class SoftmaxCrossEntropyWithLogits(Loss):
+class SoftmaxCrossEntropyWithLogits(LossBase):
     def __init__(self):
         super(SoftmaxCrossEntropyWithLogits, self).__init__()
         self.transpose = P.Transpose()

View File

@@ -13,7 +13,7 @@
 # limitations under the License.
 # ============================================================================
 """define loss function for network"""
-from mindspore.nn.loss.loss import Loss
+from mindspore.nn.loss.loss import LossBase
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore import Tensor
@@ -21,7 +21,7 @@ from mindspore.common import dtype as mstype
 import mindspore.nn as nn
-class CrossEntropy(Loss):
+class CrossEntropy(LossBase):
     """the redefined loss function with SoftmaxCrossEntropyWithLogits"""
     def __init__(self, smooth_factor=0., num_classes=1001):

View File

@@ -14,13 +14,13 @@
 # ============================================================================
 """CTC Loss."""
 import numpy as np
-from mindspore.nn.loss.loss import Loss
+from mindspore.nn.loss.loss import LossBase
 from mindspore import Tensor, Parameter
 from mindspore.common import dtype as mstype
 from mindspore.ops import operations as P
-class CTCLoss(Loss):
+class CTCLoss(LossBase):
     """
     CTCLoss definition

View File

@@ -14,10 +14,10 @@
 # ============================================================================
 """NLLLoss cell"""
 import mindspore.ops.operations as P
-from mindspore.nn.loss.loss import Loss
+from mindspore.nn.loss.loss import LossBase
 from mindspore.ops import functional as F
-class NLLLoss(Loss):
+class NLLLoss(LossBase):
     '''
     NLLLoss function
     '''

View File

@@ -17,7 +17,7 @@ import mindspore.nn as nn
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore.common import dtype as mstype
-from mindspore.nn.loss.loss import Loss
+from mindspore.nn.loss.loss import LossBase
 from mindspore import Tensor
 eps = 1e-24
@@ -41,7 +41,7 @@ class log_softmax(nn.Cell):
         return result
-class CEWithIgnoreIndex3D(Loss):
+class CEWithIgnoreIndex3D(LossBase):
     '''CEWithIgnoreIndex3D'''
     def __init__(self):
         super(CEWithIgnoreIndex3D, self).__init__()

View File

@@ -13,7 +13,7 @@
 # limitations under the License.
 # ============================================================================
 """Face Recognition loss."""
-from mindspore.nn.loss.loss import Loss
+from mindspore.nn.loss.loss import LossBase
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore import Tensor
@@ -23,7 +23,7 @@ import mindspore.nn as nn
 eps = 1e-24
-class CrossEntropyNew(Loss):
+class CrossEntropyNew(LossBase):
     '''CrossEntropyNew'''
     def __init__(self, smooth_factor=0., num_classes=1000):
         super(CrossEntropyNew, self).__init__()
@@ -42,7 +42,7 @@ class CrossEntropyNew(Loss):
         return loss
-class CrossEntropy(Loss):
+class CrossEntropy(LossBase):
     '''CrossEntropy'''
     def __init__(self):
         super(CrossEntropy, self).__init__()
@@ -106,7 +106,7 @@ class CrossEntropyWithIgnoreIndex(nn.Cell):
 eps = 1e-24
-class CEWithIgnoreIndex3D(Loss):
+class CEWithIgnoreIndex3D(LossBase):
     '''CEWithIgnoreIndex3D'''
     def __init__(self):
         super(CEWithIgnoreIndex3D, self).__init__()

View File

@@ -16,12 +16,12 @@
 import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore.common import dtype as mstype
-from mindspore.nn.loss.loss import Loss
+from mindspore.nn.loss.loss import LossBase
 from mindspore.ops import functional as F
 from mindspore.ops import operations as P
-class CrossEntropySmooth(Loss):
+class CrossEntropySmooth(LossBase):
     """CrossEntropy"""
     def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
         super(CrossEntropySmooth, self).__init__()

View File

@@ -14,7 +14,7 @@
 # ============================================================================
 """define loss function for network."""
-from mindspore.nn.loss.loss import Loss
+from mindspore.nn.loss.loss import LossBase
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore import Tensor
@@ -22,7 +22,7 @@ from mindspore.common import dtype as mstype
 import mindspore.nn as nn
-class LabelSmoothingCrossEntropy(Loss):
+class LabelSmoothingCrossEntropy(LossBase):
     """cross-entropy with label smoothing"""
     def __init__(self, smooth_factor=0.1, num_classes=1000):

View File

@@ -16,7 +16,7 @@
 import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore.common import dtype as mstype
-from mindspore.nn.loss.loss import Loss
+from mindspore.nn.loss.loss import LossBase
 from mindspore.ops import functional as F
 from mindspore.ops import operations as P
 import mindspore.ops as ops
@@ -57,7 +57,7 @@ class SoftmaxCrossEntropyExpand(nn.Cell):  # pylint: disable=missing-docstring
         return loss
-class CrossEntropySmooth(Loss):
+class CrossEntropySmooth(LossBase):
     """CrossEntropy"""
     def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):

View File

@@ -16,12 +16,12 @@
 import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore.common import dtype as mstype
-from mindspore.nn.loss.loss import Loss
+from mindspore.nn.loss.loss import LossBase
 from mindspore.ops import functional as F
 from mindspore.ops import operations as P
-class CrossEntropySmooth(Loss):
+class CrossEntropySmooth(LossBase):
     """CrossEntropy"""
     def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
         super(CrossEntropySmooth, self).__init__()

View File

@@ -14,7 +14,7 @@
 # ============================================================================
 """define loss function for network."""
-from mindspore.nn.loss.loss import Loss
+from mindspore.nn.loss.loss import LossBase
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore import Tensor
@@ -22,7 +22,7 @@ from mindspore.common import dtype as mstype
 import mindspore.nn as nn
-class LabelSmoothingCrossEntropy(Loss):
+class LabelSmoothingCrossEntropy(LossBase):
     """cross-entropy with label smoothing"""
     def __init__(self, smooth_factor=0.1, num_classes=1000):

View File

@@ -16,12 +16,12 @@
 import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore.common import dtype as mstype
-from mindspore.nn.loss.loss import Loss
+from mindspore.nn.loss.loss import LossBase
 from mindspore.ops import functional as F
 from mindspore.ops import operations as P
-class CrossEntropySmooth(Loss):
+class CrossEntropySmooth(LossBase):
     """CrossEntropy"""
     def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
         super(CrossEntropySmooth, self).__init__()

View File

@@ -14,7 +14,7 @@
 # ============================================================================
 """define loss function for network."""
-from mindspore.nn.loss.loss import Loss
+from mindspore.nn.loss.loss import LossBase
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore import Tensor
@@ -22,7 +22,7 @@ from mindspore.common import dtype as mstype
 import mindspore.nn as nn
-class LabelSmoothingCrossEntropy(Loss):
+class LabelSmoothingCrossEntropy(LossBase):
     """cross-entropy with label smoothing"""
     def __init__(self, smooth_factor=0.1, num_classes=1000):

View File

@@ -13,7 +13,7 @@
 # limitations under the License.
 # ============================================================================
 """define loss function for network"""
-from mindspore.nn.loss.loss import Loss
+from mindspore.nn.loss.loss import LossBase
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore import Tensor
@@ -21,7 +21,7 @@ from mindspore.common import dtype as mstype
 import mindspore.nn as nn
-class CrossEntropy(Loss):
+class CrossEntropy(LossBase):
     """the redefined loss function with SoftmaxCrossEntropyWithLogits"""
     def __init__(self, smooth_factor=0., num_classes=1001):

View File

@@ -14,14 +14,14 @@
 # ===========================================================================
 """DSCNN loss."""
 import mindspore.nn as nn
-from mindspore.nn.loss.loss import Loss
+from mindspore.nn.loss.loss import LossBase
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore import Tensor
 from mindspore.common import dtype as mstype
-class CrossEntropy(Loss):
+class CrossEntropy(LossBase):
     '''Build CrossEntropy Loss.'''
     def __init__(self, smooth_factor=0., num_classes=1000):
         super(CrossEntropy, self).__init__()

View File

@@ -23,7 +23,7 @@ import mindspore.ops.functional as F
 from mindspore import Tensor
 from mindspore.common.initializer import TruncatedNormal
 from mindspore.communication.management import init
-from mindspore.nn.loss.loss import Loss
+from mindspore.nn.loss.loss import LossBase
 from mindspore.nn.optim.momentum import Momentum
 from mindspore.ops import operations as P
 from mindspore.parallel import set_algo_parameters
@@ -245,7 +245,7 @@ def resnet50(class_num=10):
                   class_num)
-class SoftmaxCrossEntropyExpand(Loss):
+class SoftmaxCrossEntropyExpand(LossBase):
     def __init__(self, sparse=False):
         super(SoftmaxCrossEntropyExpand, self).__init__()
         self.exp = P.Exp()

View File

@@ -16,12 +16,12 @@
 import mindspore.nn as nn
 from mindspore import Tensor
 from mindspore.common import dtype as mstype
-from mindspore.nn.loss.loss import Loss
+from mindspore.nn.loss.loss import LossBase
 from mindspore.ops import functional as F
 from mindspore.ops import operations as P
-class CrossEntropySmooth(Loss):
+class CrossEntropySmooth(LossBase):
     """CrossEntropy"""
     def __init__(self, sparse=True, reduction='mean', smooth_factor=0., num_classes=1000):
         super(CrossEntropySmooth, self).__init__()

View File

@@ -18,11 +18,11 @@ import pytest
 from mindspore import Tensor
 from mindspore.ops import operations as P
-from mindspore.nn.loss.loss import Loss
+from mindspore.nn.loss.loss import LossBase
 from mindspore.nn.loss.loss import L1Loss
 import mindspore.context as context
-class WeightedLoss(Loss):
+class WeightedLoss(LossBase):
     def __init__(self, reduction='mean', weights=1.0):
         super(WeightedLoss, self).__init__(reduction)
         self.abs = P.Abs()
@@ -72,7 +72,7 @@ def test_weighted_loss_float32():
 def test_weighted_loss_float64():
     weighted_loss(np.float64)
-class CustomLoss(Loss):
+class CustomLoss(LossBase):
     def __init__(self, reduction='mean'):
         super(CustomLoss, self).__init__(reduction)
         self.abs = P.Abs()

View File

@@ -20,7 +20,7 @@ import numpy as np
 from mindspore.train.callback import Callback
 from mindspore import Tensor
 from mindspore import nn
-from mindspore.nn.loss.loss import Loss
+from mindspore.nn.loss.loss import LossBase
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore.common import dtype as mstype
@@ -87,7 +87,7 @@ class Monitor(Callback):
         run_context.request_stop()
-class CrossEntropyWithLabelSmooth(Loss):
+class CrossEntropyWithLabelSmooth(LossBase):
     """
     CrossEntropyWith LabelSmooth.

View File

@@ -20,7 +20,7 @@ import numpy as np
 from mindspore.train.callback import Callback
 from mindspore import Tensor
 from mindspore import nn
-from mindspore.nn.loss.loss import Loss
+from mindspore.nn.loss.loss import LossBase
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
 from mindspore.common import dtype as mstype
@@ -85,7 +85,7 @@ class Monitor(Callback):
         run_context.request_stop()
-class CrossEntropy(Loss):
+class CrossEntropy(LossBase):
     """the redefined loss function with SoftmaxCrossEntropyWithLogits"""
     def __init__(self, smooth_factor=0, num_classes=1001):

View File

@@ -23,7 +23,7 @@ from mindspore import context
 from mindspore.common.api import _executor
 from mindspore.common.initializer import TruncatedNormal
 from mindspore.communication.management import init
-from mindspore.nn.loss.loss import Loss
+from mindspore.nn.loss.loss import LossBase
 from mindspore.nn.optim.momentum import Momentum
 from mindspore.ops import operations as P
 from mindspore.parallel import _cost_model_context as cost_model_context
@@ -220,7 +220,7 @@ def resnet50(class_num=10):
                   class_num)
-class SoftmaxCrossEntropyExpand(Loss):
+class SoftmaxCrossEntropyExpand(LossBase):
     def __init__(self, sparse=False):
         super(SoftmaxCrossEntropyExpand, self).__init__()
         self.exp = P.Exp()

View File

@@ -20,7 +20,7 @@ from mindspore import Tensor, Parameter
 from mindspore import context
 from mindspore.common import dtype as mstype
 from mindspore.common.api import _executor
-from mindspore.nn.loss.loss import Loss
+from mindspore.nn.loss.loss import LossBase
 from mindspore.ops import composite as C
 from mindspore.ops import operations as P
 from tests.ut.python.ops.test_math_ops import VirtualLoss
@@ -59,7 +59,7 @@ class CustomMatMul(nn.Cell):
         return out
-class MarginCE(Loss):
+class MarginCE(LossBase):
     def __init__(self):
         super(MarginCE, self).__init__()
         self.fc = CustomMatMul(transpose_b=True)

View File

@@ -22,7 +22,7 @@ from mindspore.common import dtype as mstype
 from mindspore.common.parameter import ParameterTuple
 from mindspore.communication.management import init
 from mindspore.nn import Dense, Cell
-from mindspore.nn.loss.loss import Loss
+from mindspore.nn.loss.loss import LossBase
 from mindspore.nn.optim import Momentum
 from mindspore.ops import composite as C
 from mindspore.ops import functional as F
@@ -64,7 +64,7 @@ class Dataset():
         return self
-class GatherV2(Loss):
+class GatherV2(LossBase):
     def __init__(self, index_dim, strategy, index_size=16):
         super(GatherV2, self).__init__()
         self.pow = P.Pow()
@@ -195,7 +195,7 @@ def test_strategy3():
     net_trains(criterion, rank)
-class GatherV2Axis1(Loss):
+class GatherV2Axis1(LossBase):
     def __init__(self, index_dim, strategy, index_size=16):
         super(GatherV2Axis1, self).__init__()
         self.pow = P.Pow()

View File

@@ -22,7 +22,7 @@ from mindspore import context
 import mindspore.common.dtype as mstype
 from mindspore.common.api import _executor
 from mindspore.common.parameter import Parameter
-from mindspore.nn.loss.loss import Loss
+from mindspore.nn.loss.loss import LossBase
 from mindspore.nn.optim.momentum import Momentum
 from mindspore.ops import operations as P
 from mindspore.ops import functional as F
@@ -68,7 +68,7 @@ class AllToAllNet(nn.Cell):
         return x
-class SoftmaxCrossEntropyWithLogits(Loss):
+class SoftmaxCrossEntropyWithLogits(LossBase):
     def __init__(self,
                  sparse=False,
                  reduction='none'):