forked from mindspore-Ecosystem/mindspore
!20680 [assistant][ops] Add New SoftMarginLoss and SoftMarginLossGrad
Merge pull request !20680 from 孟权令/SoftMarginLoss
Commit: f154db1bb0
@@ -312,6 +312,8 @@ inline const PrimitivePtr kPrimBinaryCrossEntropy = std::make_shared<Primitive>(
inline const PrimitivePtr kPrimBinaryCrossEntropyGrad = std::make_shared<Primitive>("BinaryCrossEntropyGrad");
inline const PrimitivePtr kPrimSmoothL1Loss = std::make_shared<Primitive>("SmoothL1Loss");
inline const PrimitivePtr kPrimSmoothL1LossGrad = std::make_shared<Primitive>("SmoothL1LossGrad");
inline const PrimitivePtr kPrimSoftMarginLoss = std::make_shared<Primitive>("SoftMarginLoss");
inline const PrimitivePtr kPrimSoftMarginLossGrad = std::make_shared<Primitive>("SoftMarginLossGrad");
inline const PrimitivePtr kPrimSoftmaxCrossEntropyWithLogits =
  std::make_shared<Primitive>("SoftmaxCrossEntropyWithLogits");
inline const PrimitivePtr kPrimSigmoidCrossEntropyWithLogits =
@@ -0,0 +1,62 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "ops/grad/soft_margin_loss_grad.h"

#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/tensor_construct_utils.h"
#include "abstract/primitive_infer_map.h"

namespace mindspore {
namespace ops {
namespace {
constexpr size_t kInputSize = 3;
abstract::ShapePtr SoftMarginLossGradInferShape(const PrimitivePtr &primitive,
                                                const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto op_name = primitive->name();
  CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, kInputSize, op_name);
  auto predict = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
  auto label = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
  auto dout = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[2]->BuildShape())[kShape];
  CheckAndConvertUtils::Check("logits shape", predict, kEqual, "labels shape", label, op_name, ValueError);
  if (dout.size() > 1) {
    CheckAndConvertUtils::Check("logits shape", predict, kEqual, "dout shape", dout, op_name, ValueError);
  }
  return std::make_shared<abstract::Shape>(predict);
}

TypePtr SoftMarginLossGradInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  auto op_name = primitive->name();
  CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, kInputSize, op_name);
  const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
  std::map<std::string, TypePtr> types;
  types.emplace("logits", input_args[0]->BuildType());
  types.emplace("labels", input_args[1]->BuildType());
  types.emplace("dout", input_args[2]->BuildType());
  CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, op_name);
  return CheckAndConvertUtils::CheckTensorTypeValid("logits", input_args[0]->BuildType(), valid_types, op_name);
}
}  // namespace

AbstractBasePtr SoftMarginLossGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                        const std::vector<AbstractBasePtr> &input_args) {
  return abstract::MakeAbstract(SoftMarginLossGradInferShape(primitive, input_args),
                                SoftMarginLossGradInferType(primitive, input_args));
}
REGISTER_PRIMITIVE_EVAL_IMPL(SoftMarginLossGrad, prim::kPrimSoftMarginLossGrad, SoftMarginLossGradInfer, nullptr, true);
}  // namespace ops
}  // namespace mindspore
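A note on the `dout.size() > 1` guard above: when the forward SoftMarginLoss ran with reduction 'mean' or 'sum', the incoming gradient `dout` is a scalar, so its shape is only required to match `predict` in the element-wise ('none') case. A rough Python sketch of that shape contract (an illustration of the assumption, not the MindSpore kernel):

import numpy as np

def check_grad_shapes(predict, label, dout):
    # predict and label must always match element-wise.
    assert predict.shape == label.shape
    # dout is element-wise only when the forward reduction was 'none';
    # for 'mean'/'sum' it is a scalar, so the shape comparison is skipped.
    if dout.size > 1:
        assert dout.shape == predict.shape

check_grad_shapes(np.zeros((2, 2)), np.zeros((2, 2)), np.float32(1.0))  # scalar dout
check_grad_shapes(np.zeros((2, 2)), np.zeros((2, 2)), np.ones((2, 2)))  # element-wise dout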
@@ -0,0 +1,43 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_OPS_SOFT_MARGIN_LOSS_GRAD_H_
#define MINDSPORE_CORE_OPS_SOFT_MARGIN_LOSS_GRAD_H_
#include <memory>
#include <map>
#include <vector>
#include <set>
#include <string>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"

namespace mindspore {
namespace ops {
constexpr auto kNameSoftMarginLossGrad = "SoftMarginLossGrad";
class SoftMarginLossGrad : public PrimitiveC {
 public:
  SoftMarginLossGrad() : PrimitiveC(kNameSoftMarginLossGrad) { InitIOName({"predict", "label", "dout"}, {"gradient"}); }
  ~SoftMarginLossGrad() = default;
  MS_DECLARE_PARENT(SoftMarginLossGrad, PrimitiveC);
};

AbstractBasePtr SoftMarginLossGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                        const std::vector<AbstractBasePtr> &input_args);
using PrimSoftMarginLossGradPtr = std::shared_ptr<SoftMarginLossGrad>;
}  // namespace ops
}  // namespace mindspore
#endif  // MINDSPORE_CORE_OPS_SOFT_MARGIN_LOSS_GRAD_H_
@@ -0,0 +1,63 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "ops/soft_margin_loss.h"

#include "ops/op_utils.h"
#include "utils/check_convert_utils.h"
#include "utils/tensor_construct_utils.h"
#include "abstract/primitive_infer_map.h"

namespace mindspore {
namespace ops {
namespace {
constexpr size_t kInputSize = 2;
abstract::ShapePtr SoftMarginLossInferShape(const PrimitivePtr &primitive,
                                            const std::vector<AbstractBasePtr> &input_args) {
  MS_EXCEPTION_IF_NULL(primitive);
  auto op_name = primitive->name();
  CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, kInputSize, op_name);
  auto predict = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
  auto label = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[1]->BuildShape())[kShape];
  CheckAndConvertUtils::Check("logits shape", predict, kEqual, "labels shape", label, op_name, ValueError);
  auto out_shape = predict;
  int64_t reduction;
  CheckAndConvertUtils::GetReductionEnumValue(primitive->GetAttr(kReduction), &reduction);
  if (reduction == REDUCTION_SUM || reduction == MEAN) {
    out_shape.resize(0);
  }
  return std::make_shared<abstract::Shape>(out_shape);
}

TypePtr SoftMarginLossInferType(const PrimitivePtr &primitive, const std::vector<AbstractBasePtr> &input_args) {
  auto op_name = primitive->name();
  CheckAndConvertUtils::CheckInteger("input numbers", SizeToLong(input_args.size()), kEqual, kInputSize, op_name);
  const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
  std::map<std::string, TypePtr> types;
  types.emplace("logits", input_args[0]->BuildType());
  types.emplace("labels", input_args[1]->BuildType());
  CheckAndConvertUtils::CheckTensorTypeSame(types, valid_types, op_name);
  return input_args[0]->BuildType();
}
}  // namespace

AbstractBasePtr SoftMarginLossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                    const std::vector<AbstractBasePtr> &input_args) {
  return abstract::MakeAbstract(SoftMarginLossInferShape(primitive, input_args),
                                SoftMarginLossInferType(primitive, input_args));
}
REGISTER_PRIMITIVE_EVAL_IMPL(SoftMarginLoss, prim::kPrimSoftMarginLoss, SoftMarginLossInfer, nullptr, true);
}  // namespace ops
}  // namespace mindspore
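The reduction handling above only affects the inferred output shape: with reduction 'none' the loss keeps the logits shape, while 'mean' and 'sum' collapse it to a scalar (rank-0, hence `out_shape.resize(0)`). A minimal NumPy sketch of the same rule, assuming the standard soft margin formula (an illustration, not the TBE kernel):

import numpy as np

def soft_margin_loss_ref(x, y, reduction="mean"):
    loss = np.log1p(np.exp(-y * x))  # element-wise log(1 + exp(-y*x))
    if reduction == "mean":
        return loss.mean()  # scalar, matching the rank-0 inferred shape
    if reduction == "sum":
        return loss.sum()   # scalar, matching the rank-0 inferred shape
    return loss             # 'none': same shape as x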
@@ -0,0 +1,44 @@
/**
 * Copyright 2021 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#ifndef MINDSPORE_CORE_OPS_SOFT_MARGIN_LOSS_H_
#define MINDSPORE_CORE_OPS_SOFT_MARGIN_LOSS_H_
#include <memory>
#include <map>
#include <vector>
#include <set>
#include <string>
#include "ops/primitive_c.h"
#include "abstract/abstract_value.h"
#include "utils/check_convert_utils.h"

namespace mindspore {
namespace ops {
constexpr auto kNameSoftMarginLoss = "SoftMarginLoss";
class SoftMarginLoss : public PrimitiveC {
 public:
  SoftMarginLoss() : PrimitiveC(kNameSoftMarginLoss) { InitIOName({"predict", "label"}, {"loss"}); }
  ~SoftMarginLoss() = default;
  MS_DECLARE_PARENT(SoftMarginLoss, PrimitiveC);
};

AbstractBasePtr SoftMarginLossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                    const std::vector<AbstractBasePtr> &input_args);
using PrimSoftMarginLossPtr = std::shared_ptr<SoftMarginLoss>;
}  // namespace ops
}  // namespace mindspore

#endif  // MINDSPORE_CORE_OPS_SOFT_MARGIN_LOSS_H_
@@ -175,6 +175,21 @@ void CheckAndConvertUtils::GetPadModEnumValue(const ValuePtr &value, int64_t *en
  }
}

void CheckAndConvertUtils::GetReductionEnumValue(const ValuePtr &value, int64_t *enum_value) {
  MS_EXCEPTION_IF_NULL(value);
  if (value->isa<StringImm>()) {
    auto attr_value_str = GetValue<std::string>(value);

    std::map<std::string, int64_t> reduction_map = ReductionToEnumMap;
    if (reduction_map.find(attr_value_str) == reduction_map.end()) {
      MS_LOG(EXCEPTION) << "Invalid reduction mode " << attr_value_str << ", use none, mean or sum";
    }
    *enum_value = reduction_map[attr_value_str];
  } else {
    *enum_value = GetValue<int64_t>(value);
  }
}

AttrConverterPair CheckAndConvertUtils::GetAttrConvertPair(const std::string &op_type, const std::string &attr_name) {
  AttrConverterPair attr_pair;
  if (op_type.empty() || attr_name.empty()) {
@@ -297,6 +297,7 @@ class CheckAndConvertUtils {
  static AttrConverterPair GetAttrConvertPair(const std::string &op_type, const std::string &attr_name);
  static bool GetDataFormatEnumValue(const ValuePtr &value, int64_t *enum_value);
  static void GetPadModEnumValue(const ValuePtr &value, int64_t *enum_value, bool is_upper = false);
  static void GetReductionEnumValue(const ValuePtr &value, int64_t *enum_value);
  static bool CheckIrAttrtoOpAttr(const std::string &op_type, const std::string &attr_name, ValuePtr *const value);
  static void CheckSummaryParam(const AbstractBasePtr &name, const AbstractBasePtr &value,
                                const std::string &class_name);
@@ -19,13 +19,13 @@ Cells of loss function. Loss function in machine learning is the target of the m
It shows how well the model works on a dataset and the optimization target which the optimizer is searching.
"""

from .loss import LossBase, L1Loss, MSELoss, SmoothL1Loss, FocalLoss,\
from .loss import LossBase, L1Loss, MSELoss, SmoothL1Loss, SoftMarginLoss, FocalLoss,\
    SoftmaxCrossEntropyWithLogits, BCELoss, CosineEmbeddingLoss, \
    SampledSoftmaxLoss, DiceLoss, BCEWithLogitsLoss, MultiClassDiceLoss,\
    RMSELoss, MAELoss


__all__ = ['LossBase', 'L1Loss', 'MSELoss', 'SmoothL1Loss', 'FocalLoss',
__all__ = ['LossBase', 'L1Loss', 'MSELoss', 'SmoothL1Loss', 'SoftMarginLoss', 'FocalLoss',
           'SoftmaxCrossEntropyWithLogits', 'BCELoss', 'BCEWithLogitsLoss',
           'CosineEmbeddingLoss', 'SampledSoftmaxLoss', 'DiceLoss', 'MultiClassDiceLoss',
           'RMSELoss', 'MAELoss']
@@ -436,6 +436,53 @@ class SmoothL1Loss(LossBase):
        return self.smooth_l1_loss(base, target)


class SoftMarginLoss(LossBase):
    r"""
    A loss class for two-class classification problems.

    SoftMarginLoss creates a criterion that optimizes a two-class classification
    logistic loss between input tensor :math:`x` and target tensor :math:`y`
    (containing 1 or -1).

    .. math::
        \text{loss}(x, y) = \sum_i \frac{\log(1 + \exp(-y[i]*x[i]))}{\text{x.nelement}()}

    Args:
        reduction (str): Apply specific reduction method to the output: 'none', 'mean', 'sum'. Default: "mean".

    Inputs:
        - **logits** (Tensor) - Predict data. Data type must be float16 or float32.
        - **labels** (Tensor) - Ground truth data, with the same type and shape as `logits`.

    Outputs:
        Tensor or Scalar. If `reduction` is "none", its shape is the same as `logits`.
        Otherwise, a scalar value will be returned.

    Raises:
        TypeError: If `logits` or `labels` is not a Tensor.
        TypeError: If dtype of `logits` or `labels` is neither float16 nor float32.
        ValueError: If shape of `logits` is not the same as `labels`.
        ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> loss = nn.SoftMarginLoss()
        >>> logits = Tensor(np.array([[0.3, 0.7], [0.5, 0.5]]), mindspore.float32)
        >>> labels = Tensor(np.array([[-1, 1], [1, -1]]), mindspore.float32)
        >>> output = loss(logits, labels)
        >>> print(output)
        0.6764238
    """
    def __init__(self, reduction='mean'):
        super(SoftMarginLoss, self).__init__()
        self.soft_margin_loss = P.SoftMarginLoss(reduction)

    def construct(self, base, target):
        return self.soft_margin_loss(base, target)


class SoftmaxCrossEntropyWithLogits(LossBase):
    r"""
    Computes softmax cross entropy between logits and labels.
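As a quick sanity check of the SoftMarginLoss docstring above, the documented output can be reproduced with plain NumPy by applying the formula with the default 'mean' reduction:

import numpy as np

logits = np.array([[0.3, 0.7], [0.5, 0.5]], dtype=np.float32)
labels = np.array([[-1, 1], [1, -1]], dtype=np.float32)
print(np.log1p(np.exp(-labels * logits)).mean())  # ~0.6764238, matching the example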
@@ -1282,10 +1329,10 @@ class FocalLoss(LossBase):
            convert_weight = self.squeeze(convert_weight)
            log_probability = log_probability * convert_weight

        weight = F.pows(-probability + 1.0, self.gamma)
        weight = F.pows(-1 * probability + 1.0, self.gamma)
        if target.shape[1] == 1:
            loss = (-weight * log_probability).mean(axis=1)
            loss = (-1 * weight * log_probability).mean(axis=1)
        else:
            loss = (-weight * targets * log_probability).mean(axis=-1)
            loss = (-1 * weight * targets * log_probability).mean(axis=-1)

        return self.get_loss(loss)
@@ -34,6 +34,19 @@ def get_bprop_ctc_loss_v2(self):
    return bprop


@bprop_getters.register(P.SoftMarginLoss)
def get_bprop_soft_margin_loss(self):
    """Grad definition for `SoftMarginLoss` operation."""
    grad = G.SoftMarginLossGrad(reduction=self.reduction)

    def bprop(predict, label, out, dout):
        dx = grad(predict, label, dout)
        dy = grad(label, predict, dout)
        return dx, dy

    return bprop


@bprop_getters.register(P.SoftShrink)
def get_bprop_softshrink(self):
    """Grad definition for `SoftShrink` operation."""
@@ -220,6 +220,8 @@ from .arg_max_with_value import _arg_max_with_value_tbe
from .arg_min_with_value import _arg_min_with_value_tbe
from .smooth_l1_loss import _smooth_l1_loss_tbe
from .smooth_l1_loss_grad import _smooth_l1_loss_grad_tbe
from .soft_margin_loss import _soft_margin_loss_tbe
from .soft_margin_loss_grad import _soft_margin_loss_grad_tbe
from .fused_mul_add import _fused_mul_add_tbe
from .fused_mul_add_n import _fused_mul_add_n_tbe
from .fused_mul_apply_momentum import _fused_mul_apply_momentum_tbe
@@ -0,0 +1,38 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""SoftMarginLoss op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

soft_margin_loss_op_info = TBERegOp("SoftMarginLoss") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("soft_margin_loss.so") \
    .compute_cost(10) \
    .kernel_name("soft_margin_loss") \
    .partial_flag(True) \
    .attr("reduction", "optional", "str", "mean,none,sum", "mean") \
    .input(0, "input_x", False, "required", "all") \
    .input(1, "input_y", False, "required", "all") \
    .output(0, "output_z", False, "required", "all") \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
    .get_op_info()


@op_info_register(soft_margin_loss_op_info)
def _soft_margin_loss_tbe():
    """SoftMarginLoss TBE register"""
    return
@@ -0,0 +1,39 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""SoftMarginLossGrad op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType

soft_margin_loss_grad_op_info = TBERegOp("SoftMarginLossGrad") \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("soft_margin_loss_grad.so") \
    .compute_cost(10) \
    .kernel_name("soft_margin_loss_grad") \
    .partial_flag(True) \
    .attr("reduction", "optional", "str", "mean,none,sum", "mean") \
    .input(0, "predict", False, "required", "all") \
    .input(1, "label", False, "required", "all") \
    .input(2, "dout", False, "required", "all") \
    .output(0, "gradient", False, "required", "all") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
    .get_op_info()


@op_info_register(soft_margin_loss_grad_op_info)
def _soft_margin_loss_grad_tbe():
    """SoftMarginLossGrad TBE register"""
    return
@@ -78,7 +78,7 @@ from .nn_ops import (LSTM, SGD, Adam, FusedSparseAdam, FusedSparseLazyAdam, Adam
                     MaxPoolWithArgmax, OneHot, Pad, MirrorPad, Mish, PReLU, ReLU, ReLU6, ReLUV2, HSwish, HSigmoid,
                     ResizeBilinear, Sigmoid, SeLU, HShrink,
                     SigmoidCrossEntropyWithLogits, NLLLoss, BCEWithLogitsLoss,
                     SmoothL1Loss, Softmax, Softsign, Softplus, LRN, RNNTLoss, DynamicRNN, DynamicGRUV2,
                     SmoothL1Loss, SoftMarginLoss, Softmax, Softsign, Softplus, LRN, RNNTLoss, DynamicRNN, DynamicGRUV2,
                     SoftmaxCrossEntropyWithLogits, ROIAlign,
                     SparseSoftmaxCrossEntropyWithLogits, Tanh,
                     TopK, BinaryCrossEntropy, KLDivLoss, SparseApplyAdagrad, LARSUpdate, ApplyFtrl, SparseApplyFtrl,
@@ -286,6 +286,7 @@ __all__ = [
    'FloatStatus',
    'Reciprocal',
    'SmoothL1Loss',
    'SoftMarginLoss',
    'L2Loss',
    'CTCLoss',
    'CTCGreedyDecoder',
@@ -1831,6 +1831,15 @@ class SmoothL1LossGrad(PrimitiveWithInfer):
        return dloss


class SoftMarginLossGrad(Primitive):
    """Computes gradient for prediction on SoftMarginLoss."""

    @prim_attr_register
    def __init__(self, reduction="mean"):
        self.init_prim_io_names(inputs=['predict', 'label', "dout"], outputs=['gradient'])
        self.reduction = validator.check_string(reduction, ['none', 'sum', 'mean'], 'reduction', self.name)


class StridedSliceGrad(PrimitiveWithInfer):
    """
    Performs grad of StridedSlice operation.
@@ -2659,6 +2659,53 @@ class SmoothL1Loss(PrimitiveWithInfer):
        return prediction


class SoftMarginLoss(Primitive):
    r"""
    SoftMarginLoss operation.

    Creates a criterion that optimizes a two-class classification
    logistic loss between input tensor :math:`x` and target tensor :math:`y`
    (containing 1 or -1).

    .. math::
        \text{loss}(x, y) = \sum_i \frac{\log(1 + \exp(-y[i]*x[i]))}{\text{x.nelement}()}

    Args:
        reduction (str): Apply specific reduction method to the output: 'none', 'mean', 'sum'. Default: "mean".

    Inputs:
        - **logits** (Tensor) - Predict data. Data type must be float16 or float32.
        - **labels** (Tensor) - Ground truth data, with the same type and shape as `logits`.

    Outputs:
        Tensor or Scalar. If `reduction` is "none", its shape is the same as `logits`.
        Otherwise, a scalar value will be returned.

    Raises:
        TypeError: If `logits` or `labels` is not a Tensor.
        TypeError: If dtype of `logits` or `labels` is neither float16 nor float32.
        ValueError: If shape of `logits` is not the same as `labels`.
        ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.

    Supported Platforms:
        ``Ascend``

    Examples:
        >>> loss = ops.SoftMarginLoss()
        >>> logits = Tensor(np.array([[0.3, 0.7], [0.5, 0.5]]), mindspore.float32)
        >>> labels = Tensor(np.array([[-1, 1], [1, -1]]), mindspore.float32)
        >>> output = loss(logits, labels)
        >>> print(output)
        0.6764238
    """

    @prim_attr_register
    def __init__(self, reduction="mean"):
        """Initialize SoftMarginLoss"""
        self.init_prim_io_names(inputs=['predict', 'label'], outputs=['loss'])
        self.reduction = validator.check_string(reduction, ['none', 'sum', 'mean'], 'reduction', self.name)


class L2Loss(PrimitiveWithInfer):
    """
    Calculates half of the L2 norm of a tensor without using the `sqrt`.
@@ -2119,6 +2119,11 @@ test_case_nn_ops = [
        'block': P.L2Loss(),
        'desc_inputs': [Tensor(np.array([[1, 1], [2, 2], [3, 3], [4, 4]]), mstype.float16)],
        'desc_bprop': []}),
    ('SoftMarginLoss', {
        'block': P.SoftMarginLoss(reduction="none"),
        'desc_inputs': [Tensor(np.array([[0.3, 0.7], [0.5, 0.5]]).astype(np.float32)),
                        Tensor(np.array([[-1, 1], [1, -1]]).astype(np.float32))],
        'desc_bprop': [Tensor(np.array([[1, 1], [1, 1]]).astype(np.float32))]}),
    ('BCEWithLogitsLoss', {
        'block': P.BCEWithLogitsLoss(),
        'desc_inputs': [[3, 3], [3, 3], [3, 3], [3, 3]],