diff --git a/mindspore/core/ops/grad/smooth_l1_loss_grad.cc b/mindspore/core/ops/grad/smooth_l1_loss_grad.cc
index 7572fe52d3c..b863b86ed80 100644
--- a/mindspore/core/ops/grad/smooth_l1_loss_grad.cc
+++ b/mindspore/core/ops/grad/smooth_l1_loss_grad.cc
@@ -34,21 +34,26 @@ float SmoothL1LossGrad::get_beta() const {
   return GetValue<float>(value_ptr);
 }
 
-AbstractBasePtr SmoothL1LossGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                      const std::vector<AbstractBasePtr> &input_args) {
+namespace {
+abstract::ShapePtr SmoothL1LossGradInferShape(const PrimitivePtr &primitive,
+                                              const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   auto prim_name = primitive->name();
   const int64_t input_num = 3;
-  (void)CheckAndConvertUtils::CheckInteger("smooth_l1_loss_grad_infer", SizeToLong(input_args.size()), kEqual,
-                                           input_num, prim_name);
-
-  // Infer shape
-  auto prediction = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
-  auto target = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
-  auto dloss = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
-  CheckAndConvertUtils::Check("prediction shape", prediction, kEqual, "target shape", target, prim_name, TypeError);
-  CheckAndConvertUtils::Check("prediction shape", prediction, kEqual, "dloss", dloss, prim_name, TypeError);
+  CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
+  auto prediction = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex0);
+  auto target = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex1);
+  auto dloss = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex2);
+  abstract::CheckShapeSame(prim_name, prediction, target);
+  abstract::CheckShapeSame(prim_name, prediction, dloss);
+  auto x = input_args[kInputIndex0]->BuildShape();
+  MS_EXCEPTION_IF_NULL(x);
+  auto shape_element = x->cast<abstract::ShapePtr>();
+  MS_EXCEPTION_IF_NULL(shape_element);
+  return shape_element;
+}
 
+TypePtr SmoothL1LossGradInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
   // Infer type
   const std::set<TypePtr> valid_types = {kBool,   kInt,    kInt8,   kInt16, kInt32,   kInt64,   kUInt,    kUInt8,
                                          kUInt16, kUInt32, kUInt64, kFloat, kFloat16, kFloat32, kFloat64, kComplex64};
@@ -56,10 +61,17 @@ AbstractBasePtr SmoothL1LossGradInfer(const abstract::AnalysisEnginePtr &, const
   (void)args.emplace("prediction", input_args[kInputIndex0]->BuildType());
   (void)args.emplace("target", input_args[kInputIndex1]->BuildType());
   (void)args.emplace("dloss", input_args[kInputIndex2]->BuildType());
-  auto dloss_type = CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
-
-  return std::make_shared<abstract::AbstractTensor>(dloss_type, prediction);
+  auto dloss_type = CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim->name());
+  return dloss_type;
 }
-REGISTER_PRIMITIVE_C(kNameSmoothL1LossGrad, SmoothL1LossGrad);
+}  // namespace
+
+AbstractBasePtr SmoothL1LossGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                      const std::vector<AbstractBasePtr> &input_args) {
+  auto infer_type = SmoothL1LossGradInferType(primitive, input_args);
+  auto infer_shape = SmoothL1LossGradInferShape(primitive, input_args);
+  return abstract::MakeAbstract(infer_shape, infer_type);
+}
+REGISTER_PRIMITIVE_EVAL_IMPL(SmoothL1LossGrad, prim::kPrimSmoothL1LossGrad, SmoothL1LossGradInfer, nullptr, true);
 }  // namespace ops
 }  // namespace mindspore
diff --git a/mindspore/core/ops/grad/smooth_l1_loss_grad.h b/mindspore/core/ops/grad/smooth_l1_loss_grad.h
index d638f51d4d4..71bdec5fc0a 100644
--- a/mindspore/core/ops/grad/smooth_l1_loss_grad.h
+++ b/mindspore/core/ops/grad/smooth_l1_loss_grad.h
@@ -38,6 +38,7 @@ class MS_CORE_API SmoothL1LossGrad : public PrimitiveC {
 };
 AbstractBasePtr SmoothL1LossGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                       const std::vector<AbstractBasePtr> &input_args);
+using kPrimSmoothL1LossGradPtr = std::shared_ptr<SmoothL1LossGrad>;
 }  // namespace ops
 }  // namespace mindspore
 #endif  // MINDSPORE_CORE_OPS_SMOOTH_L1_LOSS_GRAD_H_
diff --git a/mindspore/core/ops/smooth_l1_loss.cc b/mindspore/core/ops/smooth_l1_loss.cc
index 0a843686521..09a831be31a 100644
--- a/mindspore/core/ops/smooth_l1_loss.cc
+++ b/mindspore/core/ops/smooth_l1_loss.cc
@@ -33,27 +33,40 @@ float SmoothL1Loss::get_beta() const {
   return GetValue<float>(value_ptr);
 }
 
-AbstractBasePtr SmoothL1LossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
-                                  const std::vector<AbstractBasePtr> &input_args) {
+namespace {
+abstract::ShapePtr SmoothL1LossInferShape(const PrimitivePtr &primitive,
+                                          const std::vector<AbstractBasePtr> &input_args) {
   MS_EXCEPTION_IF_NULL(primitive);
   auto prim_name = primitive->name();
   const int64_t input_num = 2;
   CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
+  auto prediction = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex0);
+  auto target = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex1);
+  abstract::CheckShapeSame(prim_name, prediction, target);
+  auto x = input_args[kInputIndex0]->BuildShape();
+  MS_EXCEPTION_IF_NULL(x);
+  auto shape_element = x->cast<abstract::ShapePtr>();
+  MS_EXCEPTION_IF_NULL(shape_element);
+  return shape_element;
+}
 
-  // Infer shape
-  auto prediction = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
-  auto target = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
-  CheckAndConvertUtils::Check("prediction shape", prediction, kEqual, "target shape", target, prim_name, TypeError);
-
+TypePtr SmoothL1LossInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
   // Infer type
   const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
   std::map<std::string, TypePtr> args;
-  args.emplace("scale", input_args[0]->BuildType());
-  args.emplace("bias", input_args[1]->BuildType());
-  auto prediction_type = CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
-
-  return std::make_shared<abstract::AbstractTensor>(prediction_type, prediction);
+  (void)args.emplace("scale", input_args[kInputIndex0]->BuildType());
+  (void)args.emplace("bias", input_args[kInputIndex1]->BuildType());
+  auto prediction_type = CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim->name());
+  return prediction_type;
 }
-REGISTER_PRIMITIVE_C(kNameSmoothL1Loss, SmoothL1Loss);
+}  // namespace
+
+AbstractBasePtr SmoothL1LossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
+                                  const std::vector<AbstractBasePtr> &input_args) {
+  auto infer_type = SmoothL1LossInferType(primitive, input_args);
+  auto infer_shape = SmoothL1LossInferShape(primitive, input_args);
+  return abstract::MakeAbstract(infer_shape, infer_type);
+}
+REGISTER_PRIMITIVE_EVAL_IMPL(SmoothL1Loss, prim::kPrimSmoothL1Loss, SmoothL1LossInfer, nullptr, true);
 }  // namespace ops
 }  // namespace mindspore
diff --git a/mindspore/core/ops/smooth_l1_loss.h b/mindspore/core/ops/smooth_l1_loss.h
index 16f1ce9fc41..d0faf93bf61 100644
--- a/mindspore/core/ops/smooth_l1_loss.h
+++ b/mindspore/core/ops/smooth_l1_loss.h
@@ -45,6 +45,7 @@ class MS_CORE_API SmoothL1Loss : public PrimitiveC {
 };
 AbstractBasePtr SmoothL1LossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
                                   const std::vector<AbstractBasePtr> &input_args);
+using kPrimSmoothL1LossPtr = std::shared_ptr<SmoothL1Loss>;
 }  // namespace ops
 }  // namespace mindspore
 
diff --git a/mindspore/python/mindspore/ops/_op_impl/tbe/__init__.py b/mindspore/python/mindspore/ops/_op_impl/tbe/__init__.py
index cafcac1d824..44ae641e784 100644
--- a/mindspore/python/mindspore/ops/_op_impl/tbe/__init__.py
+++ b/mindspore/python/mindspore/ops/_op_impl/tbe/__init__.py
@@ -282,7 +282,9 @@ from .pad_d_ds import _pad_d_ds_tbe
 from .arg_max_with_value import _arg_max_with_value_tbe
 from .arg_min_with_value import _arg_min_with_value_tbe
 from .smooth_l1_loss import _smooth_l1_loss_tbe
+from .smooth_l1_loss_ds import _smooth_l1_loss_ds_tbe
 from .smooth_l1_loss_grad import _smooth_l1_loss_grad_tbe
+from .smooth_l1_loss_grad_ds import _smooth_l1_loss_grad_ds_tbe
 from .soft_margin_loss import _soft_margin_loss_tbe
 from .soft_margin_loss_grad import _soft_margin_loss_grad_tbe
 from .fused_mul_add import _fused_mul_add_tbe
diff --git a/mindspore/python/mindspore/ops/_op_impl/tbe/smooth_l1_loss_ds.py b/mindspore/python/mindspore/ops/_op_impl/tbe/smooth_l1_loss_ds.py
new file mode 100644
index 00000000000..94ee5daa4b4
--- /dev/null
+++ b/mindspore/python/mindspore/ops/_op_impl/tbe/smooth_l1_loss_ds.py
@@ -0,0 +1,45 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""SmoothL1Loss op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+smooth_l1_loss_op_info = TBERegOp("SmoothL1Loss") \
+    .fusion_type("OPAQUE") \
+    .async_flag(False) \
+    .binfile_name("smooth_l1_loss.so") \
+    .compute_cost(10) \
+    .kernel_name("smooth_l1_loss") \
+    .partial_flag(True) \
+    .dynamic_shape(True) \
+    .attr("beta", "required", "float", "all") \
+    .input(0, "predict", False, "required", "all") \
+    .input(1, "label", False, "required", "all") \
+    .output(0, "loss", False, "required", "all") \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
+    .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ) \
+    .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
+    .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \
+    .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
+    .get_op_info()
+
+
+@op_info_register(smooth_l1_loss_op_info)
+def _smooth_l1_loss_ds_tbe():
+    """SmoothL1Loss TBE register"""
+    return
diff --git a/mindspore/python/mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad_ds.py b/mindspore/python/mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad_ds.py
new file mode 100644
index 00000000000..f055b2f9ef5
--- /dev/null
+++ b/mindspore/python/mindspore/ops/_op_impl/tbe/smooth_l1_loss_grad_ds.py
@@ -0,0 +1,46 @@
+# Copyright 2021 Huawei Technologies Co., Ltd
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ============================================================================
+
+"""SmoothL1LossGrad op"""
+from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
+
+smooth_l1_loss_grad_op_info = TBERegOp("SmoothL1LossGrad") \
+    .fusion_type("OPAQUE") \
+    .async_flag(False) \
+    .binfile_name("smooth_l1_loss_grad.so") \
+    .compute_cost(10) \
+    .kernel_name("smooth_l1_loss_grad") \
+    .partial_flag(True) \
+    .dynamic_shape(True) \
+    .attr("beta", "required", "float", "all") \
+    .input(0, "predict", False, "required", "all") \
+    .input(1, "label", False, "required", "all") \
+    .input(2, "dout", False, "required", "all") \
+    .output(0, "loss", False, "required", "all") \
+    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
+    .dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
+    .dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ) \
+    .dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
+    .dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
+    .dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
+    .dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \
+    .dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
+    .get_op_info()
+
+
+@op_info_register(smooth_l1_loss_grad_op_info)
+def _smooth_l1_loss_grad_ds_tbe():
+    """SmoothL1LossGrad TBE register"""
+    return
diff --git a/mindspore/python/mindspore/ops/operations/_grad_ops.py b/mindspore/python/mindspore/ops/operations/_grad_ops.py
index 3f972d0c4cb..42257fec732 100644
--- a/mindspore/python/mindspore/ops/operations/_grad_ops.py
+++ b/mindspore/python/mindspore/ops/operations/_grad_ops.py
@@ -1822,23 +1822,13 @@ class NLLLossGrad(PrimitiveWithInfer):
         return x_dtype
 
 
-class SmoothL1LossGrad(PrimitiveWithInfer):
+class SmoothL1LossGrad(Primitive):
     """Computes gradient for prediction on SmoothL1Loss."""
 
     @prim_attr_register
     def __init__(self, beta=1.0):
         pass
 
-    def infer_shape(self, prediction, target, dloss):
-        validator.check('prediction shape', prediction, 'target shape', target, Rel.EQ, self.name)
-        validator.check('prediction shape', prediction, 'dloss shape', dloss, Rel.EQ, self.name)
-        return prediction
-
-    def infer_dtype(self, prediction, target, dloss):
-        args = {"prediction": prediction, "target": target, 'dloss': dloss}
-        validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
-        return dloss
-
 
 class SoftMarginLossGrad(Primitive):
     """Computes gradient for prediction on SoftMarginLoss."""
diff --git a/mindspore/python/mindspore/ops/operations/nn_ops.py b/mindspore/python/mindspore/ops/operations/nn_ops.py
index d4a7d2a5072..666e5f13de1 100644
--- a/mindspore/python/mindspore/ops/operations/nn_ops.py
+++ b/mindspore/python/mindspore/ops/operations/nn_ops.py
@@ -2604,7 +2604,7 @@ class ApplyMomentum(Primitive):
         self.add_prim_attr('side_effect_mem', True)
 
 
-class SmoothL1Loss(PrimitiveWithInfer):
+class SmoothL1Loss(Primitive):
     r"""
     Computes smooth L1 loss, a robust L1 loss.
 
@@ -2667,15 +2667,6 @@ class SmoothL1Loss(PrimitiveWithInfer):
         validator.check('beta', beta, '', 0, Rel.GT, self.name)
         self.init_prim_io_names(inputs=['prediction', 'target'], outputs=['output'])
 
-    def infer_shape(self, prediction, target):
-        validator.check('prediction shape', prediction, 'target shape', target, Rel.EQ, self.name)
-        return prediction
-
-    def infer_dtype(self, prediction, target):
-        args = {"prediction": prediction, "target": target}
-        validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
-        return prediction
-
 
 class SoftMarginLoss(Primitive):
     r"""