[feat][assistant][I3T91Y][I3T91Z]add new operator SoftMarginLoss and SoftMarginLossGrad

This commit is contained in:
SichongHao 2021-08-06 15:13:13 +08:00
parent e8e0976bb1
commit 3d2892b454
17 changed files with 437 additions and 6 deletions

View File

@ -312,6 +312,8 @@ inline const PrimitivePtr kPrimBinaryCrossEntropy = std::make_shared<Primitive>(
inline const PrimitivePtr kPrimBinaryCrossEntropyGrad = std::make_shared<Primitive>("BinaryCrossEntropyGrad");
inline const PrimitivePtr kPrimSmoothL1Loss = std::make_shared<Primitive>("SmoothL1Loss");
inline const PrimitivePtr kPrimSmoothL1LossGrad = std::make_shared<Primitive>("SmoothL1LossGrad");
inline const PrimitivePtr kPrimSoftMarginLoss = std::make_shared<Primitive>("SoftMarginLoss");
inline const PrimitivePtr kPrimSoftMarginLossGrad = std::make_shared<Primitive>("SoftMarginLossGrad");
inline const PrimitivePtr kPrimSoftmaxCrossEntropyWithLogits =
std::make_shared<Primitive>("SoftmaxCrossEntropyWithLogits");
inline const PrimitivePtr kPrimSigmoidCrossEntropyWithLogits =

View File

@ -0,0 +1,62 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/grad/soft_margin_loss_grad.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/tensor_construct_utils.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore {
namespace ops {
namespace {
constexpr size_t kInputSize = 3;
abstract::ShapePtr SoftMarginLossGradInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, kInputSize, op_name);
auto predict = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto label = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
auto dout = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
CheckAndConvertUtils::Check("logits shape", predict, kEqual, "labels shape", label, op_name, ValueError);
if (dout.size() > 1) {
CheckAndConvertUtils::Check("logits shape", predict, kEqual, "dout shape", dout, op_name, ValueError);
}
return std::make_shared<abstract::Shape>(predict);
}
TypePtr SoftMarginLossGradInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
auto op_name = primitive->name();
CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, kInputSize, op_name);
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
std::map<std::string, TypePtr> types;
types.emplace("logits", input_args[0]->BuildType());
types.emplace("labels", input_args[1]->BuildType());
types.emplace("dout", input_args[2]->BuildType());
CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, op_name);
return CheckAndConvertUtils::CheckTensorTypeValid("logits", input_args[0]->BuildType(), valid_types, op_name);
}
} // namespace
AbstractBasePtr SoftMarginLossGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return abstract::MakeAbstract(SoftMarginLossGradInferShape(primitive, input_args),
SoftMarginLossGradInferType(primitive, input_args));
}
REGISTER_PRIMITIVE_EVAL_IMPL(SoftMarginLossGrad, prim::kPrimSoftMarginLossGrad, SoftMarginLossGradInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,43 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_SOFT_MARGIN_LOSS_GRAD_H_
#define MINDSPORE_CORE_OPS_SOFT_MARGIN_LOSS_GRAD_H_
#include <memory>
#include <map>
#include <vector>
#include <set>
#include <string>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameSoftMarginLossGrad = "SoftMarginLossGrad";
class SoftMarginLossGrad : public PrimitiveC {
public:
SoftMarginLossGrad() : PrimitiveC(kNameSoftMarginLossGrad) { InitIOName({"predict", "label", "dout"}, {"gradient"}); }
~SoftMarginLossGrad() = default;
MS_DECLARE_PARENT(SoftMarginLossGrad, PrimitiveC);
};
AbstractBasePtr SoftMarginLossGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimSoftMarginLossGradPtr = std::shared_ptr<SoftMarginLossGrad>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_SOFT_MARGIN_LOSS_GRAD_H_

View File

@ -0,0 +1,63 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "ops/soft_margin_loss.h"
#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/tensor_construct_utils.h"
#include "abstract/primitive_infer_map.h"
namespace mindspore {
namespace ops {
namespace {
constexpr size_t kInputSize = 2;
abstract::ShapePtr SoftMarginLossInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto op_name = primitive->name();
CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, kInputSize, op_name);
auto predict = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto label = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
CheckAndConvertUtils::Check("logits shape", predict, kEqual, "labels shape", label, op_name, ValueError);
auto out_shape = predict;
int64_t reduction;
CheckAndConvertUtils::GetReductionEnumValue(primitive->GetAttr(kReduction), &reduction);
if (reduction == REDUCTION_SUM || reduction == MEAN) {
out_shape.resize(0);
}
return std::make_shared<abstract::Shape>(out_shape);
}
TypePtr SoftMarginLossInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
auto op_name = primitive->name();
CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, kInputSize, op_name);
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
std::map<std::string, TypePtr> types;
types.emplace("logits", input_args[0]->BuildType());
types.emplace("labels", input_args[1]->BuildType());
CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, op_name);
return input_args[0]->BuildType();
}
} // namespace
AbstractBasePtr SoftMarginLossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
return abstract::MakeAbstract(SoftMarginLossInferShape(primitive, input_args),
SoftMarginLossInferType(primitive, input_args));
}
REGISTER_PRIMITIVE_EVAL_IMPL(SoftMarginLoss, prim::kPrimSoftMarginLoss, SoftMarginLossInfer, nullptr, true);
} // namespace ops
} // namespace mindspore

View File

@ -0,0 +1,44 @@
/**
* Copyright 2021 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CORE_OPS_SOFT_MARGIN_LOSS_H_
#define MINDSPORE_CORE_OPS_SOFT_MARGIN_LOSS_H_
#include <memory>
#include <map>
#include <vector>
#include <set>
#include <string>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"
namespace mindspore {
namespace ops {
constexpr auto kNameSoftMarginLoss = "SoftMarginLoss";
class SoftMarginLoss : public PrimitiveC {
public:
SoftMarginLoss() : PrimitiveC(kNameSoftMarginLoss) { InitIOName({"predict", "label"}, {"loss"}); }
~SoftMarginLoss() = default;
MS_DECLARE_PARENT(SoftMarginLoss, PrimitiveC);
};
AbstractBasePtr SoftMarginLossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using PrimSoftMarginLossPtr = std::shared_ptr<SoftMarginLoss>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_SOFT_MARGIN_LOSS_H_

View File

@ -175,6 +175,21 @@ void CheckAndConvertUtils::GetPadModEnumValue(const ValuePtr &value, int64_t *en
}
}
void CheckAndConvertUtils::GetReductionEnumValue(const ValuePtr &value, int64_t *enum_value) {
MS_EXCEPTION_IF_NULL(value);
if (value->isa<StringImm>()) {
auto attr_value_str = GetValue<std::string>(value);
std::map<std::string, int64_t> pad_map = ReductionToEnumMap;
if (pad_map.find(attr_value_str) == pad_map.end()) {
MS_LOG(EXCEPTION) << "Invalid pad mode " << attr_value_str << " use pad, valid or same";
}
*enum_value = pad_map[attr_value_str];
} else {
*enum_value = GetValue<int64_t>(value);
}
}
AttrConverterPair CheckAndConvertUtils::GetAttrConvertPair(const std::string &op_type, const std::string &attr_name) {
AttrConverterPair attr_pair;
if (op_type.empty() || attr_name.empty()) {

View File

@ -297,6 +297,7 @@ class CheckAndConvertUtils {
static AttrConverterPair GetAttrConvertPair(const std::string &op_type, const std::string &attr_name);
static bool GetDataFormatEnumValue(const ValuePtr &value, int64_t *enum_value);
static void GetPadModEnumValue(const ValuePtr &value, int64_t *enum_value, bool is_upper = false);
static void GetReductionEnumValue(const ValuePtr &value, int64_t *enum_value);
static bool CheckIrAttrtoOpAttr(const std::string &op_type, const std::string &attr_name, ValuePtr *const value);
static void CheckSummaryParam(const AbstractBasePtr &name, const AbstractBasePtr &value,
const std::string &class_name);

View File

@ -19,13 +19,13 @@ Cells of loss function. Loss function in machine learning is the target of the m
It shows how well the model works on a dataset and the optimization target which the optimizer is searching.
"""
from .loss import LossBase, L1Loss, MSELoss, SmoothL1Loss, FocalLoss,\
from .loss import LossBase, L1Loss, MSELoss, SmoothL1Loss, SoftMarginLoss, FocalLoss,\
SoftmaxCrossEntropyWithLogits, BCELoss, CosineEmbeddingLoss, \
SampledSoftmaxLoss, DiceLoss, BCEWithLogitsLoss, MultiClassDiceLoss,\
RMSELoss, MAELoss
__all__ = ['LossBase', 'L1Loss', 'MSELoss', 'SmoothL1Loss', 'FocalLoss',
__all__ = ['LossBase', 'L1Loss', 'MSELoss', 'SmoothL1Loss', 'SoftMarginLoss', 'FocalLoss',
'SoftmaxCrossEntropyWithLogits', 'BCELoss', 'BCEWithLogitsLoss',
'CosineEmbeddingLoss', 'SampledSoftmaxLoss', 'DiceLoss', 'MultiClassDiceLoss',
'RMSELoss', 'MAELoss']

View File

@ -436,6 +436,53 @@ class SmoothL1Loss(LossBase):
return self.smooth_l1_loss(base, target)
class SoftMarginLoss(LossBase):
r"""
A loss class for two-class classification problems.
SoftMarginLoss creates a criterion that optimizes a two-class classification
logistic loss between input tensor :math:`x` and target tensor :math:`y`
(containing 1 or -1).
.. math::
\text{loss}(x, y) = \sum_i \frac{\log(1 + \exp(-y[i]*x[i]))}{\text{x.nelement}()}
Args:
reduction (str): Apply specific reduction method to the output: 'none', 'mean', 'sum'. Default: "mean".
Inputs:
- **logits** (Tensor) - Predict data. Data type must be float16 or float32.
- **labels** (Tensor) - Ground truth data, with the same type and shape as `logits`.
Outputs:
Tensor or Scalar, if `reduction` is "none", its shape is the same as `logits`.
Otherwise, a scalar value will be returned.
Raises:
TypeError: If `logits` or `labels` is not a Tensor.
TypeError: If dtype of `logits` or `labels` is neither float16 nor float32.
ValueError: If shape of `logits` is not the same as `labels`.
ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.
Supported Platforms:
``Ascend``
Examples:
>>> loss = ops.SoftMarginLoss()
>>> logits = Tensor(np.array([[0.3, 0.7], [0.5, 0.5]]), mindspore.float32)
>>> labels = Tensor(np.array([[-1, 1], [1, -1]]), mindspore.float32)
>>> output = loss(logits, labels)
>>> print(output)
0.6764238
"""
def __init__(self, reduction='mean'):
super(SoftMarginLoss, self).__init__()
self.soft_margin_loss = P.SoftMarginLoss(reduction)
def construct(self, base, target):
return self.soft_margin_loss(base, target)
class SoftmaxCrossEntropyWithLogits(LossBase):
r"""
Computes softmax cross entropy between logits and labels.
@ -1282,10 +1329,10 @@ class FocalLoss(LossBase):
convert_weight = self.squeeze(convert_weight)
log_probability = log_probability * convert_weight
weight = F.pows(-probability + 1.0, self.gamma)
weight = F.pows(-1 * probability + 1.0, self.gamma)
if target.shape[1] == 1:
loss = (-weight * log_probability).mean(axis=1)
loss = (-1 * weight * log_probability).mean(axis=1)
else:
loss = (-weight * targets * log_probability).mean(axis=-1)
loss = (-1 * weight * targets * log_probability).mean(axis=-1)
return self.get_loss(loss)

View File

@ -34,6 +34,19 @@ def get_bprop_ctc_loss_v2(self):
return bprop
@bprop_getters.register(P.SoftMarginLoss)
def get_bprop_soft_margin_loss(self):
"""Grad definition for `SoftMarginLoss` operation."""
grad = G.SoftMarginLossGrad(reduction=self.reduction)
def bprop(predict, label, out, dout):
dx = grad(predict, label, dout)
dy = grad(label, predict, dout)
return dx, dy
return bprop
@bprop_getters.register(P.SoftShrink)
def get_bprop_softshrink(self):
"""Grad definition for `SoftShrink` operation."""

View File

@ -220,6 +220,8 @@ from .arg_max_with_value import _arg_max_with_value_tbe
from .arg_min_with_value import _arg_min_with_value_tbe
from .smooth_l1_loss import _smooth_l1_loss_tbe
from .smooth_l1_loss_grad import _smooth_l1_loss_grad_tbe
from .soft_margin_loss import _soft_margin_loss_tbe
from .soft_margin_loss_grad import _soft_margin_loss_grad_tbe
from .fused_mul_add import _fused_mul_add_tbe
from .fused_mul_add_n import _fused_mul_add_n_tbe
from .fused_mul_apply_momentum import _fused_mul_apply_momentum_tbe

View File

@ -0,0 +1,38 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""SoftMarginLoss op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
soft_margin_loss_op_info = TBERegOp("SoftMarginLoss") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("soft_margin_loss.so") \
.compute_cost(10) \
.kernel_name("soft_margin_loss") \
.partial_flag(True) \
.attr("reduction", "optional", "str", "mean,none,sum", "mean") \
.input(0, "input_x", False, "required", "all") \
.input(1, "input_y", False, "required", "all") \
.output(0, "output_z", False, "required", "all") \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.get_op_info()
@op_info_register(soft_margin_loss_op_info)
def _soft_margin_loss_tbe():
"""SoftMarginLoss TBE register"""
return

View File

@ -0,0 +1,39 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""SoftMarginLossGrad op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
soft_margin_loss_grad_op_info = TBERegOp("SoftMarginLossGrad") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("soft_margin_loss_grad.so") \
.compute_cost(10) \
.kernel_name("soft_margin_loss_grad") \
.partial_flag(True) \
.attr("reduction", "optional", "str", "mean,none,sum", "mean") \
.input(0, "predict", False, "required", "all") \
.input(1, "label", False, "required", "all") \
.input(2, "dout", False, "required", "all") \
.output(0, "gradient", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.get_op_info()
@op_info_register(soft_margin_loss_grad_op_info)
def _soft_margin_loss_grad_tbe():
"""SoftMarginLossGrad TBE register"""
return

View File

@ -78,7 +78,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Adam
MaxPoolWithArgmax, OneHot, Pad, MirrorPad, Mish, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid,
ResizeBilinear, Sigmoid, SeLU, HShrink,
SigmoidCrossEntropyWithLogits, NLLLoss, BCEWithLogitsLoss,
SmoothL1Loss, Softmax, Softsign, Softplus, LRN, RNNTLoss, DynamicRNN, DynamicGRUV2,
SmoothL1Loss, SoftMarginLoss, Softmax, Softsign, Softplus, LRN, RNNTLoss, DynamicRNN, DynamicGRUV2,
SoftmaxCrossEntropyWithLogits, ROIAlign,
SparseSoftmaxCrossEntropyWithLogits, Tanh,
TopK, BinaryCrossEntropy, KLDivLoss, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl,
@ -286,6 +286,7 @@ __all__ = [
'FloatStatus',
'Reciprocal',
'SmoothL1Loss',
'SoftMarginLoss',
'L2Loss',
'CTCLoss',
'CTCGreedyDecoder',

View File

@ -1831,6 +1831,15 @@ class SmoothL1LossGrad(PrimitiveWithInfer):
return dloss
class SoftMarginLossGrad(Primitive):
"""Computes gradient for prediction on SoftMarginLoss."""
@prim_attr_register
def __init__(self, reduction="mean"):
self.init_prim_io_names(inputs=['predict', 'label', "dout"], outputs=['gradient'])
self.reduction = validator.check_string(reduction, ['none', 'sum', 'mean'], 'reduction', self.name)
class StridedSliceGrad(PrimitiveWithInfer):
"""
Performs grad of StridedSlice operation.

View File

@ -2659,6 +2659,53 @@ class SmoothL1Loss(PrimitiveWithInfer):
return prediction
class SoftMarginLoss(Primitive):
r"""
SoftMarginLoss operation.
Creates a criterion that optimizes a two-class classification
logistic loss between input tensor :math:`x` and target tensor :math:`y`
(containing 1 or -1).
.. math::
\text{loss}(x, y) = \sum_i \frac{\log(1 + \exp(-y[i]*x[i]))}{\text{x.nelement}()}
Args:
reduction (str): Apply specific reduction method to the output: 'none', 'mean', 'sum'. Default: "mean".
Inputs:
- **logits** (Tensor) - Predict data. Data type must be float16 or float32.
- **labels** (Tensor) - Ground truth data, with the same type and shape as `logits`.
Outputs:
Tensor or Scalar, if `reduction` is "none", its shape is the same as `logits`.
Otherwise, a scalar value will be returned.
Raises:
TypeError: If `logits` or `labels` is not a Tensor.
TypeError: If dtype of `logits` or `labels` is neither float16 nor float32.
ValueError: If shape of `logits` is not the same as `labels`.
ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.
Supported Platforms:
``Ascend``
Examples:
>>> loss = ops.SoftMarginLoss()
>>> logits = Tensor(np.array([[0.3, 0.7], [0.5, 0.5]]), mindspore.float32)
>>> labels = Tensor(np.array([[-1, 1], [1, -1]]), mindspore.float32)
>>> output = loss(logits, labels)
>>> print(output)
0.6764238
"""
@prim_attr_register
def __init__(self, reduction="mean"):
"""Initialize SoftMarginLoss"""
self.init_prim_io_names(inputs=['predict', 'label'], outputs=['loss'])
self.reduction = validator.check_string(reduction, ['none', 'sum', 'mean'], 'reduction', self.name)
class L2Loss(PrimitiveWithInfer):
"""
Calculates half of the L2 norm of a tensor without using the `sqrt`.

View File

@ -2119,6 +2119,11 @@ test_case_nn_ops = [
'block': P.L2Loss(),
'desc_inputs': [Tensor(np.array([[1, 1], [2, 2], [3, 3], [4, 4]]), mstype.float16)],
'desc_bprop': []}),
('SoftMarginLoss', {
'block': P.SoftMarginLoss(reduction="none"),
'desc_inputs': [Tensor(np.array([[0.3, 0.7], [0.5, 0.5]]).astype(np.float32)),
Tensor(np.array([[-1, 1], [1, -1]]).astype(np.float32))],
'desc_bprop': [Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))]}),
('BCEWithLogitsLoss', {
'block': P.BCEWithLogitsLoss(),
'desc_inputs': [[3, 3], [3, 3], [3, 3], [3, 3]],