forked from mindspore-Ecosystem/mindspore
[feat] [assistant] [I54KH0] add new aicpu operator SmoothL1LossV2
This commit is contained in:
parent
49fba49aef
commit
572744594e
|
@ -27,11 +27,8 @@ mindspore.ops.smooth_l1_loss
|
||||||
|
|
||||||
其中, :math:`\text{beta}` 控制损失函数从二次元变为线性的point。默认值是1.0。 :math:`N` 为batch size。
|
其中, :math:`\text{beta}` 控制损失函数从二次元变为线性的point。默认值是1.0。 :math:`N` 为batch size。
|
||||||
|
|
||||||
.. note::
|
|
||||||
在Ascend上,目前不支持 `logits` 的数据类型是float64。
|
|
||||||
|
|
||||||
参数:
|
参数:
|
||||||
- **logits** (Tensor) - shape: :math:`(N, *)` ,其中 :math:`*` 表示任意数量的附加维度。数据类型为float16或float32, CPU和GPU后端还支持float64。
|
- **logits** (Tensor) - shape: :math:`(N, *)` ,其中 :math:`*` 表示任意数量的附加维度。数据类型为float16,float32和float64。
|
||||||
- **labels** (Tensor) - shape: :math:`(N, *)` ,与 `logits` 的shape和数据类型相同。
|
- **labels** (Tensor) - shape: :math:`(N, *)` ,与 `logits` 的shape和数据类型相同。
|
||||||
- **beta** (float) - 控制损失函数在L1Loss和L2Loss间变换的阈值。默认值:1.0。
|
- **beta** (float) - 控制损失函数在L1Loss和L2Loss间变换的阈值。默认值:1.0。
|
||||||
- **reduction** (str) - 缩减输出的方法。默认值:'none'。其他选项:'mean'和'sum'。
|
- **reduction** (str) - 缩减输出的方法。默认值:'none'。其他选项:'mean'和'sum'。
|
||||||
|
@ -45,4 +42,3 @@ mindspore.ops.smooth_l1_loss
|
||||||
- **TypeError** - `logits` 或 `labels` 的数据类型不是float16,float32和float64中的任一者。
|
- **TypeError** - `logits` 或 `labels` 的数据类型不是float16,float32和float64中的任一者。
|
||||||
- **ValueError** - `beta` 小于0。
|
- **ValueError** - `beta` 小于0。
|
||||||
- **ValueError** - `logits` 与 `labels` 的shape不同。
|
- **ValueError** - `logits` 与 `labels` 的shape不同。
|
||||||
- **TypeError** - Ascend后端不支持数据类型是float64的 `logits` 输入。
|
|
||||||
|
|
|
@ -188,6 +188,8 @@ constexpr auto kSliceGrad = "SliceGrad";
|
||||||
constexpr auto kStatelessDropOutGenMask = "StatelessDropOutGenMask";
|
constexpr auto kStatelessDropOutGenMask = "StatelessDropOutGenMask";
|
||||||
constexpr auto kRaggedTensorToTensor = "RaggedTensorToTensor";
|
constexpr auto kRaggedTensorToTensor = "RaggedTensorToTensor";
|
||||||
constexpr auto kAdaptiveMaxPool3D = "AdaptiveMaxPool3D";
|
constexpr auto kAdaptiveMaxPool3D = "AdaptiveMaxPool3D";
|
||||||
|
constexpr auto kSmoothL1Loss = "SmoothL1Loss";
|
||||||
|
constexpr auto kSmoothL1LossGrad = "SmoothL1LossGrad";
|
||||||
|
|
||||||
const std::set<std::string> kCpuKernelOps{kIdentity,
|
const std::set<std::string> kCpuKernelOps{kIdentity,
|
||||||
kMaskedSelect,
|
kMaskedSelect,
|
||||||
|
@ -311,6 +313,8 @@ const std::map<std::string, std::string> kOpNameToAicpuOpNameMap{
|
||||||
{kSampleDistortedBoundingBoxV2, "SampleDistortedBoundingBoxExt2"},
|
{kSampleDistortedBoundingBoxV2, "SampleDistortedBoundingBoxExt2"},
|
||||||
{kSparseSoftmaxCrossEntropyWithLogitsV2, "SparseSoftmaxCrossEntropyWithLogits"},
|
{kSparseSoftmaxCrossEntropyWithLogitsV2, "SparseSoftmaxCrossEntropyWithLogits"},
|
||||||
{kSparseToDenseV2, "SparseToDense"},
|
{kSparseToDenseV2, "SparseToDense"},
|
||||||
|
{kSmoothL1Loss, "SmoothL1LossV2"},
|
||||||
|
{kSmoothL1LossGrad, "SmoothL1LossGradV2"},
|
||||||
{kAvgPoolV1, "AvgPool"},
|
{kAvgPoolV1, "AvgPool"},
|
||||||
{kNonZero, "Where"},
|
{kNonZero, "Where"},
|
||||||
{kAvgPoolGradV1, "AvgPoolGrad"},
|
{kAvgPoolGradV1, "AvgPoolGrad"},
|
||||||
|
|
|
@ -50,10 +50,7 @@ std::string SmoothL1LossGrad::get_reduction() const {
|
||||||
namespace {
|
namespace {
|
||||||
abstract::ShapePtr SmoothL1LossGradInferShape(const PrimitivePtr &primitive,
|
abstract::ShapePtr SmoothL1LossGradInferShape(const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
|
||||||
auto prim_name = primitive->name();
|
auto prim_name = primitive->name();
|
||||||
const int64_t input_num = 3;
|
|
||||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
|
|
||||||
auto prediction = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex0);
|
auto prediction = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex0);
|
||||||
auto target = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex1);
|
auto target = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex1);
|
||||||
abstract::CheckShapeSame(prim_name, prediction, target);
|
abstract::CheckShapeSame(prim_name, prediction, target);
|
||||||
|
@ -76,17 +73,18 @@ TypePtr SmoothL1LossGradInferType(const PrimitivePtr &prim, const std::vector<Ab
|
||||||
std::map<std::string, TypePtr> args;
|
std::map<std::string, TypePtr> args;
|
||||||
(void)args.emplace("prediction", input_args[kInputIndex0]->BuildType());
|
(void)args.emplace("prediction", input_args[kInputIndex0]->BuildType());
|
||||||
(void)args.emplace("target", input_args[kInputIndex1]->BuildType());
|
(void)args.emplace("target", input_args[kInputIndex1]->BuildType());
|
||||||
auto dloss_type = CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim->name());
|
(void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim->name());
|
||||||
return dloss_type;
|
return input_args[kInputIndex0]->BuildType();
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
MIND_API_OPERATOR_IMPL(SmoothL1LossGrad, BaseOperator);
|
MIND_API_OPERATOR_IMPL(SmoothL1LossGrad, BaseOperator);
|
||||||
AbstractBasePtr SmoothL1LossGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr SmoothL1LossGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
for (auto item : input_args) {
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
MS_EXCEPTION_IF_NULL(item);
|
auto prim_name = primitive->name();
|
||||||
}
|
const int64_t input_num = 3;
|
||||||
|
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
|
||||||
auto infer_type = SmoothL1LossGradInferType(primitive, input_args);
|
auto infer_type = SmoothL1LossGradInferType(primitive, input_args);
|
||||||
auto infer_shape = SmoothL1LossGradInferShape(primitive, input_args);
|
auto infer_shape = SmoothL1LossGradInferShape(primitive, input_args);
|
||||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||||
|
|
|
@ -50,10 +50,7 @@ std::string SmoothL1Loss::get_reduction() const {
|
||||||
namespace {
|
namespace {
|
||||||
abstract::ShapePtr SmoothL1LossInferShape(const PrimitivePtr &primitive,
|
abstract::ShapePtr SmoothL1LossInferShape(const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
MS_EXCEPTION_IF_NULL(primitive);
|
|
||||||
auto prim_name = primitive->name();
|
auto prim_name = primitive->name();
|
||||||
const int64_t input_num = 2;
|
|
||||||
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
|
|
||||||
auto prediction = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex0);
|
auto prediction = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex0);
|
||||||
auto target = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex1);
|
auto target = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex1);
|
||||||
auto prediction_shape = prediction->shape();
|
auto prediction_shape = prediction->shape();
|
||||||
|
@ -69,33 +66,29 @@ abstract::ShapePtr SmoothL1LossInferShape(const PrimitivePtr &primitive,
|
||||||
if (reduction == kNone) {
|
if (reduction == kNone) {
|
||||||
return prediction_shape;
|
return prediction_shape;
|
||||||
} else {
|
} else {
|
||||||
ShapeVector shape_out{1};
|
ShapeVector shape_out{};
|
||||||
return std::make_shared<abstract::Shape>(shape_out);
|
return std::make_shared<abstract::Shape>(shape_out);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
TypePtr SmoothL1LossInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
TypePtr SmoothL1LossInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
|
||||||
// Infer type
|
// Infer type
|
||||||
std::set<TypePtr> valid_types{};
|
const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
|
||||||
auto context = MsContext::GetInstance();
|
|
||||||
MS_EXCEPTION_IF_NULL(context);
|
|
||||||
bool is_ascend = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice);
|
|
||||||
if (is_ascend) {
|
|
||||||
valid_types = {kFloat16, kFloat32};
|
|
||||||
} else {
|
|
||||||
valid_types = {kFloat16, kFloat32, kFloat64};
|
|
||||||
}
|
|
||||||
|
|
||||||
std::map<std::string, TypePtr> args;
|
std::map<std::string, TypePtr> args;
|
||||||
(void)args.emplace("scale", input_args[kInputIndex0]->BuildType());
|
(void)args.emplace("scale", input_args[kInputIndex0]->BuildType());
|
||||||
(void)args.emplace("bias", input_args[kInputIndex1]->BuildType());
|
(void)args.emplace("bias", input_args[kInputIndex1]->BuildType());
|
||||||
auto prediction_type = CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim->name());
|
(void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim->name());
|
||||||
return prediction_type;
|
return input_args[kInputIndex0]->BuildType();
|
||||||
}
|
}
|
||||||
} // namespace
|
} // namespace
|
||||||
|
|
||||||
AbstractBasePtr SmoothL1LossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
AbstractBasePtr SmoothL1LossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
|
||||||
const std::vector<AbstractBasePtr> &input_args) {
|
const std::vector<AbstractBasePtr> &input_args) {
|
||||||
|
MS_EXCEPTION_IF_NULL(primitive);
|
||||||
|
auto prim_name = primitive->name();
|
||||||
|
const int64_t input_num = 2;
|
||||||
|
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
|
||||||
auto infer_type = SmoothL1LossInferType(primitive, input_args);
|
auto infer_type = SmoothL1LossInferType(primitive, input_args);
|
||||||
auto infer_shape = SmoothL1LossInferShape(primitive, input_args);
|
auto infer_shape = SmoothL1LossInferShape(primitive, input_args);
|
||||||
return abstract::MakeAbstract(infer_shape, infer_type);
|
return abstract::MakeAbstract(infer_shape, infer_type);
|
||||||
|
|
|
@ -0,0 +1,35 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""SmoothL1Loss op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||||
|
|
||||||
|
smooth_l1_loss_op_info = AiCPURegOp("SmoothL1Loss") \
|
||||||
|
.fusion_type("OPAQUE") \
|
||||||
|
.attr("sigma", "float") \
|
||||||
|
.attr("reduction", "str") \
|
||||||
|
.input(0, "prediction", "required") \
|
||||||
|
.input(1, "target", "required") \
|
||||||
|
.output(0, "output", "required") \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(smooth_l1_loss_op_info)
|
||||||
|
def _smooth_l1_loss_aicpu():
|
||||||
|
"""SmoothL1Loss AiCPU register"""
|
||||||
|
return
|
|
@ -0,0 +1,37 @@
|
||||||
|
# Copyright 2022 Huawei Technologies Co., Ltd
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
# ============================================================================
|
||||||
|
|
||||||
|
"""SmoothL1LossGrad op"""
|
||||||
|
from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
|
||||||
|
|
||||||
|
smooth_l1_loss_grad_op_info = AiCPURegOp("SmoothL1LossGrad") \
|
||||||
|
.fusion_type("OPAQUE") \
|
||||||
|
.attr("sigma", "float") \
|
||||||
|
.attr("reduction", "str") \
|
||||||
|
.input(0, "prediction", "required") \
|
||||||
|
.input(1, "target", "required") \
|
||||||
|
.input(2, "dout", "required") \
|
||||||
|
.output(0, "output", "required") \
|
||||||
|
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
|
||||||
|
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
|
||||||
|
.dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
|
||||||
|
.get_op_info()
|
||||||
|
|
||||||
|
|
||||||
|
@op_info_register(smooth_l1_loss_grad_op_info)
|
||||||
|
def _smooth_l1_loss_grad_aicpu():
|
||||||
|
"""SmoothL1LossGrad AiCPU register"""
|
||||||
|
return
|
||||||
|
|
|
@ -2946,9 +2946,6 @@ def smooth_l1_loss(logits, labels, beta=1.0, reduction='none'):
|
||||||
Here :math:`\text{beta}` controls the point where the loss function changes from quadratic to linear.
|
Here :math:`\text{beta}` controls the point where the loss function changes from quadratic to linear.
|
||||||
Its default value is 1.0. :math:`N` is the batch size.
|
Its default value is 1.0. :math:`N` is the batch size.
|
||||||
|
|
||||||
Note:
|
|
||||||
For Ascend platform, the float64 data type of `logits` is not support now.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
logits (Tensor): Tensor of shape :math:`(N, *)` where :math:`*` means, any number of additional dimensions.
|
logits (Tensor): Tensor of shape :math:`(N, *)` where :math:`*` means, any number of additional dimensions.
|
||||||
labels (Tensor): Ground truth data, tensor of shape :math:`(N, *)`, same shape and dtype as the `logits`.
|
labels (Tensor): Ground truth data, tensor of shape :math:`(N, *)`, same shape and dtype as the `logits`.
|
||||||
|
@ -2963,10 +2960,9 @@ def smooth_l1_loss(logits, labels, beta=1.0, reduction='none'):
|
||||||
Raises:
|
Raises:
|
||||||
TypeError: If `beta` is not a float.
|
TypeError: If `beta` is not a float.
|
||||||
ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.
|
ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.
|
||||||
TypeError: If dtype of `logits` or `labels` is neither float16 nor float32.
|
TypeError: If dtype of `logits` or `labels` is not one of float16, float32, float64.
|
||||||
ValueError: If `beta` is less than or equal to 0.
|
ValueError: If `beta` is less than or equal to 0.
|
||||||
ValueError: If shape of `logits` is not the same as `labels`.
|
ValueError: If shape of `logits` is not the same as `labels`.
|
||||||
TypeError: The float64 data type of `logits` is support on Ascend platform.
|
|
||||||
|
|
||||||
Supported Platforms:
|
Supported Platforms:
|
||||||
``Ascend`` ``GPU`` ``CPU``
|
``Ascend`` ``GPU`` ``CPU``
|
||||||
|
|
|
@ -2129,6 +2129,7 @@ class SmoothL1LossGrad(Primitive):
|
||||||
|
|
||||||
@prim_attr_register
|
@prim_attr_register
|
||||||
def __init__(self, beta=1.0, reduction='none'):
|
def __init__(self, beta=1.0, reduction='none'):
|
||||||
|
self.add_prim_attr('sigma', self.beta)
|
||||||
self.reduction = validator.check_string(
|
self.reduction = validator.check_string(
|
||||||
reduction, ['none', 'sum', 'mean'], 'reduction', self.name)
|
reduction, ['none', 'sum', 'mean'], 'reduction', self.name)
|
||||||
|
|
||||||
|
|
|
@ -2993,6 +2993,7 @@ class SmoothL1Loss(Primitive):
|
||||||
validator.check('beta', beta, '', 0, Rel.GT, self.name)
|
validator.check('beta', beta, '', 0, Rel.GT, self.name)
|
||||||
validator.check_string(
|
validator.check_string(
|
||||||
reduction, ['none', 'sum', 'mean'], 'reduction', self.name)
|
reduction, ['none', 'sum', 'mean'], 'reduction', self.name)
|
||||||
|
self.add_prim_attr('sigma', self.beta)
|
||||||
self.init_prim_io_names(inputs=['prediction', 'target'], outputs=['output'])
|
self.init_prim_io_names(inputs=['prediction', 'target'], outputs=['output'])
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue