From 572744594e13692eb9ac0fdc3485cb6cc3d7b300 Mon Sep 17 00:00:00 2001
From: misitetong <995906889@qq.com>
Date: Wed, 23 Nov 2022 14:39:29 +0800
Subject: [PATCH] [feat] [assistant] [I54KH0] add new aicpu operator
 SmoothL1LossV2

---
 .../ops/mindspore.ops.func_smooth_l1_loss.rst |  6 +--
 .../device/ascend/kernel/aicpu/aicpu_util.h   |  4 ++
 .../core/ops/grad/smooth_l1_loss_grad.cc      | 14 +++----
 mindspore/core/ops/smooth_l1_loss.cc          | 23 ++++--------
 .../ops/_op_impl/aicpu/smooth_l1_loss.py      | 35 ++++++++++++++++++
 .../ops/_op_impl/aicpu/smooth_l1_loss_grad.py | 36 +++++++++++++++++++
 .../python/mindspore/ops/function/nn_func.py  |  6 +--
 .../mindspore/ops/operations/_grad_ops.py     |  1 +
 .../python/mindspore/ops/operations/nn_ops.py |  1 +
 9 files changed, 93 insertions(+), 33 deletions(-)
 create mode 100644 mindspore/python/mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py
 create mode 100644 mindspore/python/mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py

diff --git a/docs/api/api_python/ops/mindspore.ops.func_smooth_l1_loss.rst b/docs/api/api_python/ops/mindspore.ops.func_smooth_l1_loss.rst
index 1766482be3b..b26994235cc 100644
--- a/docs/api/api_python/ops/mindspore.ops.func_smooth_l1_loss.rst
+++ b/docs/api/api_python/ops/mindspore.ops.func_smooth_l1_loss.rst
@@ -27,11 +27,8 @@ mindspore.ops.smooth_l1_loss
 
     其中, :math:`\text{beta}` 控制损失函数从二次元变为线性的point。默认值是1.0。 :math:`N` 为batch size。
 
-    .. note::
-        在Ascend上,目前不支持 `logits` 的数据类型是float64。
-
     参数:
-        - **logits** (Tensor) - shape: :math:`(N, *)` ,其中 :math:`*` 表示任意数量的附加维度。数据类型为float16或float32, CPU和GPU后端还支持float64。
+        - **logits** (Tensor) - shape: :math:`(N, *)` ,其中 :math:`*` 表示任意数量的附加维度。数据类型为float16,float32和float64。
         - **labels** (Tensor) - shape: :math:`(N, *)` ,与 `logits` 的shape和数据类型相同。
         - **beta** (float) - 控制损失函数在L1Loss和L2Loss间变换的阈值。默认值：1.0。
         - **reduction** (str) - 缩减输出的方法。默认值：'none'。其他选项：'mean'和'sum'。
@@ -45,4 +42,3 @@ mindspore.ops.smooth_l1_loss
         - **TypeError** - `logits` 或 `labels` 的数据类型不是float16,float32和float64中的任一者。
         - **ValueError** - `beta` 小于0。
         - **ValueError** - `logits` 与 `labels` 的shape不同。
-        - **TypeError** - Ascend后端不支持数据类型是float64的 `logits` 输入。
diff --git a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_util.h b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_util.h
index 052ca213cf8..924fd299716 100644
--- a/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_util.h
+++ b/mindspore/ccsrc/plugin/device/ascend/kernel/aicpu/aicpu_util.h
@@ -188,6 +188,8 @@ constexpr auto kSliceGrad = "SliceGrad";
 constexpr auto kStatelessDropOutGenMask = "StatelessDropOutGenMask";
 constexpr auto kRaggedTensorToTensor = "RaggedTensorToTensor";
 constexpr auto kAdaptiveMaxPool3D = "AdaptiveMaxPool3D";
+constexpr auto kSmoothL1Loss = "SmoothL1Loss";
+constexpr auto kSmoothL1LossGrad = "SmoothL1LossGrad";
 
 const std::set<std::string> kCpuKernelOps{kIdentity,
                                           kMaskedSelect,
@@ -311,6 +313,8 @@ const std::map<std::string, std::string> kOpNameToAicpuOpNameMap{
   {kSampleDistortedBoundingBoxV2, "SampleDistortedBoundingBoxExt2"},
   {kSparseSoftmaxCrossEntropyWithLogitsV2, "SparseSoftmaxCrossEntropyWithLogits"},
   {kSparseToDenseV2, "SparseToDense"},
+  {kSmoothL1Loss, "SmoothL1LossV2"},
+  {kSmoothL1LossGrad, "SmoothL1LossGradV2"},
   {kAvgPoolV1, "AvgPool"},
   {kNonZero, "Where"},
   {kAvgPoolGradV1, "AvgPoolGrad"},
diff --git a/mindspore/core/ops/grad/smooth_l1_loss_grad.cc b/mindspore/core/ops/grad/smooth_l1_loss_grad.cc
index 05f7da0fb6e..17a8748b28e 100644
--- a/mindspore/core/ops/grad/smooth_l1_loss_grad.cc
+++ b/mindspore/core/ops/grad/smooth_l1_loss_grad.cc
@@ -50,10 +50,7 @@ std::string SmoothL1LossGrad::get_reduction() const {
 namespace {
 abstract::ShapePtr SmoothL1LossGradInferShape(const PrimitivePtr &primitive,
                                               const std::vector<AbstractBasePtr> &input_args) {
-  MS_EXCEPTION_IF_NULL(primitive);
   auto prim_name = primitive->name();
-  const int64_t input_num = 3;
-  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
   auto prediction = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex0);
   auto target = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex1);
   abstract::CheckShapeSame(prim_name, prediction, target);
@@ -76,17 +73,18 @@ TypePtr SmoothL1LossGradInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
   std::map<std::string, TypePtr> args;
   (void)args.emplace("prediction", input_args[kInputIndex0]->BuildType());
   (void)args.emplace("target", input_args[kInputIndex1]->BuildType());
-  auto dloss_type = CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim->name());
-  return dloss_type;
+  (void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim->name());
+  return input_args[kInputIndex0]->BuildType();
 }
 }  // namespace
 
 MIND_API_OPERATOR_IMPL(SmoothL1LossGrad, BaseOperator);
 AbstractBasePtr SmoothL1LossGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                       const std::vector<AbstractBasePtr> &input_args) {
-  for (auto item : input_args) {
-    MS_EXCEPTION_IF_NULL(item);
-  }
+  MS_EXCEPTION_IF_NULL(primitive);
+  auto prim_name = primitive->name();
+  const int64_t input_num = 3;
+  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
   auto infer_type = SmoothL1LossGradInferType(primitive, input_args);
   auto infer_shape = SmoothL1LossGradInferShape(primitive, input_args);
   return abstract::MakeAbstract(infer_shape, infer_type);
diff --git a/mindspore/core/ops/smooth_l1_loss.cc b/mindspore/core/ops/smooth_l1_loss.cc
index 23252b640cc..3bd1b5086e0 100644
--- a/mindspore/core/ops/smooth_l1_loss.cc
+++ b/mindspore/core/ops/smooth_l1_loss.cc
@@ -50,10 +50,7 @@ std::string SmoothL1Loss::get_reduction() const {
 namespace {
 abstract::ShapePtr SmoothL1LossInferShape(const PrimitivePtr &primitive,
                                           const std::vector<AbstractBasePtr> &input_args) {
-  MS_EXCEPTION_IF_NULL(primitive);
   auto prim_name = primitive->name();
-  const int64_t input_num = 2;
-  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
   auto prediction = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex0);
   auto target = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex1);
   auto prediction_shape = prediction->shape();
@@ -69,33 +66,29 @@ abstract::ShapePtr SmoothL1LossInferShape(const PrimitivePtr &primitive,
   if (reduction == kNone) {
     return prediction_shape;
   } else {
-    ShapeVector shape_out{1};
+    ShapeVector shape_out{};
     return std::make_shared<abstract::Shape>(shape_out);
   }
 }
 
 TypePtr SmoothL1LossInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
   // Infer type
-  std::set<TypePtr> valid_types{};
-  auto context = MsContext::GetInstance();
-  MS_EXCEPTION_IF_NULL(context);
-  bool is_ascend = (context->get_param<std::string>(MS_CTX_DEVICE_TARGET) == kAscendDevice);
-  if (is_ascend) {
-    valid_types = {kFloat16, kFloat32};
-  } else {
-    valid_types = {kFloat16, kFloat32, kFloat64};
-  }
+  const std::set<TypePtr> valid_types = {kFloat16, kFloat32, kFloat64};
   std::map<std::string, TypePtr> args;
   (void)args.emplace("scale", input_args[kInputIndex0]->BuildType());
   (void)args.emplace("bias", input_args[kInputIndex1]->BuildType());
-  auto prediction_type = CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim->name());
-  return prediction_type;
+  (void)CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim->name());
+  return input_args[kInputIndex0]->BuildType();
 }
 }  // namespace
 
 AbstractBasePtr SmoothL1LossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                   const std::vector<AbstractBasePtr> &input_args) {
+  MS_EXCEPTION_IF_NULL(primitive);
+  auto prim_name = primitive->name();
+  const int64_t input_num = 2;
+  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
   auto infer_type = SmoothL1LossInferType(primitive, input_args);
   auto infer_shape = SmoothL1LossInferShape(primitive, input_args);
   return abstract::MakeAbstract(infer_shape, infer_type);
 }
diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py
new file mode 100644
index 00000000000..f487782c5ef
--- /dev/null
+++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/smooth_l1_loss.py
@@ -0,0 +1,35 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""SmoothL1Loss op"""
+from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
+
+smooth_l1_loss_op_info = AiCPURegOp("SmoothL1Loss") \
+    .fusion_type("OPAQUE") \
+    .attr("sigma", "float") \
+    .attr("reduction", "str") \
+    .input(0, "prediction", "required") \
+    .input(1, "target", "required") \
+    .output(0, "output", "required") \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
+    .get_op_info()
+
+
+@op_info_register(smooth_l1_loss_op_info)
+def _smooth_l1_loss_aicpu():
+    """SmoothL1Loss AiCPU register"""
+    return
diff --git a/mindspore/python/mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py b/mindspore/python/mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py
new file mode 100644
index 00000000000..a1587866b0d
--- /dev/null
+++ b/mindspore/python/mindspore/ops/_op_impl/aicpu/smooth_l1_loss_grad.py
@@ -0,0 +1,36 @@
+# Copyright 2022 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""SmoothL1LossGrad op"""
+from mindspore.ops.op_info_register import op_info_register, AiCPURegOp, DataType
+
+smooth_l1_loss_grad_op_info = AiCPURegOp("SmoothL1LossGrad") \
+    .fusion_type("OPAQUE") \
+    .attr("sigma", "float") \
+    .attr("reduction", "str") \
+    .input(0, "prediction", "required") \
+    .input(1, "target", "required") \
+    .input(2, "dout", "required") \
+    .output(0, "output", "required") \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.F64_Default, DataType.F64_Default, DataType.F64_Default, DataType.F64_Default) \
+    .get_op_info()
+
+
+@op_info_register(smooth_l1_loss_grad_op_info)
+def _smooth_l1_loss_grad_aicpu():
+    """SmoothL1LossGrad AiCPU register"""
+    return
diff --git a/mindspore/python/mindspore/ops/function/nn_func.py b/mindspore/python/mindspore/ops/function/nn_func.py
index a1c40f62e6f..98050167a59 100644
--- a/mindspore/python/mindspore/ops/function/nn_func.py
+++ b/mindspore/python/mindspore/ops/function/nn_func.py
@@ -2946,9 +2946,6 @@ def smooth_l1_loss(logits, labels, beta=1.0, reduction='none'):
     Here :math:`\text{beta}` controls the point where the loss function changes from quadratic to linear.
     Its default value is 1.0. :math:`N` is the batch size.
 
-    Note:
-        For Ascend platform, the float64 data type of `logits` is not support now.
-
     Args:
         logits (Tensor): Tensor of shape :math:`(N, *)` where :math:`*` means, any number of additional dimensions.
         labels (Tensor): Ground truth data, tensor of shape :math:`(N, *)`, same shape and dtype as the `logits`.
@@ -2963,10 +2960,9 @@ def smooth_l1_loss(logits, labels, beta=1.0, reduction='none'):
     Raises:
         TypeError: If `beta` is not a float.
         ValueError: If `reduction` is not one of 'none', 'mean', 'sum'.
-        TypeError: If dtype of `logits` or `labels` is neither float16 nor float32.
+        TypeError: If dtype of `logits` or `labels` is not one of float16, float32, float64.
         ValueError: If `beta` is less than or equal to 0.
         ValueError: If shape of `logits` is not the same as `labels`.
-        TypeError: The float64 data type of `logits` is support on Ascend platform.
 
     Supported Platforms:
         ``Ascend`` ``GPU`` ``CPU``
diff --git a/mindspore/python/mindspore/ops/operations/_grad_ops.py b/mindspore/python/mindspore/ops/operations/_grad_ops.py
index 83f346162a6..d108020a489 100644
--- a/mindspore/python/mindspore/ops/operations/_grad_ops.py
+++ b/mindspore/python/mindspore/ops/operations/_grad_ops.py
@@ -2129,6 +2129,7 @@ class SmoothL1LossGrad(Primitive):
 
     @prim_attr_register
     def __init__(self, beta=1.0, reduction='none'):
+        self.add_prim_attr('sigma', self.beta)
         self.reduction = validator.check_string(
             reduction, ['none', 'sum', 'mean'], 'reduction', self.name)
 
diff --git a/mindspore/python/mindspore/ops/operations/nn_ops.py b/mindspore/python/mindspore/ops/operations/nn_ops.py
index fdab8819b80..1f9ccd83fc2 100644
--- a/mindspore/python/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/python/mindspore/ops/operations/nn_ops.py
@@ -2993,6 +2993,7 @@ class SmoothL1Loss(Primitive):
         validator.check('beta', beta, '', 0, Rel.GT, self.name)
         validator.check_string(
             reduction, ['none', 'sum', 'mean'], 'reduction', self.name)
+        self.add_prim_attr('sigma', self.beta)
         self.init_prim_io_names(inputs=['prediction', 'target'],
                                 outputs=['output'])
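
A quick smoke test for the new float64 path (a minimal sketch, not part of the
patch: it assumes a MindSpore build that contains this change and an Ascend
target with the SmoothL1LossV2/SmoothL1LossGradV2 AICPU kernels deployed):

    import numpy as np
    import mindspore as ms
    from mindspore import Tensor, ops

    ms.set_context(device_target="Ascend")

    # float64 logits/labels were rejected on Ascend before this patch; they
    # should now dispatch to the AICPU kernel mapped via kOpNameToAicpuOpNameMap.
    logits = Tensor(np.array([1.0, 2.0, 3.0]), ms.float64)
    labels = Tensor(np.array([1.0, 2.0, 2.0]), ms.float64)

    # reduction='none' keeps the input shape; 'mean'/'sum' now infer a scalar
    # shape (ShapeVector{}) instead of the old 1-element vector (ShapeVector{1}).
    print(ops.smooth_l1_loss(logits, labels, beta=1.0, reduction='none'))  # [0.  0.  0.5]
    print(ops.smooth_l1_loss(logits, labels, beta=1.0, reduction='mean'))  # ~0.1667, shape ()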