!25906 [feat][assistant][I48O9S] add dynamic shape support for the SmoothL1Loss and SmoothL1LossGrad operations

Merge pull request !25906 from bubb1e/smooth_l1_loss
i-robot 2021-12-18 06:53:21 +00:00 committed by Gitee
commit 727636d8fb
9 changed files with 150 additions and 49 deletions
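For context, the two operators compute the standard smooth L1 (Huber-style) loss and its gradient with respect to the prediction: 0.5 * d^2 / beta when |d| < beta and |d| - 0.5 * beta otherwise, with d = prediction - target. A minimal NumPy reference sketch of these formulas (illustration only; the function names are hypothetical and this is not the TBE kernel code touched below):

import numpy as np

def smooth_l1_loss(prediction, target, beta=1.0):
    # element-wise: 0.5 * d**2 / beta if |d| < beta, else |d| - 0.5 * beta
    diff = np.abs(prediction - target)
    return np.where(diff < beta, 0.5 * diff * diff / beta, diff - 0.5 * beta)

def smooth_l1_loss_grad(prediction, target, dloss, beta=1.0):
    # gradient w.r.t. prediction: (prediction - target) / beta clipped to [-1, 1], scaled by dloss
    return np.clip((prediction - target) / beta, -1.0, 1.0) * dloss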


@@ -34,21 +34,26 @@ float SmoothL1LossGrad::get_beta() const {
return GetValue<int32_t>(value_ptr);
}
AbstractBasePtr SmoothL1LossGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
namespace {
abstract::ShapePtr SmoothL1LossGradInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
const int64_t input_num = 3;
(void)CheckAndConvertUtils::CheckInteger("smooth_l1_loss_grad_infer", SizeToLong(input_args.size()), kEqual,
input_num, prim_name);
// Infer shape
auto prediction = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex0]->BuildShape())[kShape];
auto target = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex1]->BuildShape())[kShape];
auto dloss = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[kInputIndex2]->BuildShape())[kShape];
CheckAndConvertUtils::Check("prediction shape", prediction, kEqual, "target shape", target, prim_name, TypeError);
CheckAndConvertUtils::Check("prediction shape", prediction, kEqual, "dloss", dloss, prim_name, TypeError);
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
auto prediction = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex0);
auto target = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex1);
auto dloss = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex2);
abstract::CheckShapeSame(prim_name, prediction, target);
abstract::CheckShapeSame(prim_name, prediction, dloss);
auto x = input_args[kInputIndex0]->BuildShape();
MS_EXCEPTION_IF_NULL(x);
auto shape_element = x->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(shape_element);
return shape_element;
}
TypePtr SmoothL1LossGradInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
// Infer type
const std::set<TypePtr> valid_types = {kBool, kInt, kInt8, kInt16, kInt32, kInt64, kUInt, kUInt8,
kUInt16, kUInt32, kUInt64, kFloat, kFloat16, kFloat32, kFloat64, kComplex64};
@@ -56,10 +61,17 @@ AbstractBasePtr SmoothL1LossGradInfer(const abstract::AnalysisEnginePtr &, const
(void)args.emplace("prediction", input_args[kInputIndex0]->BuildType());
(void)args.emplace("target", input_args[kInputIndex1]->BuildType());
(void)args.emplace("dloss", input_args[kInputIndex2]->BuildType());
auto dloss_type = CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
return std::make_shared<abstract::AbstractTensor>(dloss_type, prediction);
auto dloss_type = CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim->name());
return dloss_type;
}
REGISTER_PRIMITIVE_C(kNameSmoothL1LossGrad, SmoothL1LossGrad);
} // namespace
AbstractBasePtr SmoothL1LossGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
auto infer_type = SmoothL1LossGradInferType(primitive, input_args);
auto infer_shape = SmoothL1LossGradInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(SmoothL1LossGrad, prim::kPrimSmoothL1LossGrad, SmoothL1LossGradInfer, nullptr, true);
} // namespace ops
} // namespace mindspore
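The new structure splits inference into SmoothL1LossGradInferShape and SmoothL1LossGradInferType, combines them with abstract::MakeAbstract, and registers the result through REGISTER_PRIMITIVE_EVAL_IMPL, so the output abstract can be re-derived when input shapes are only known at runtime. The shape rule itself is simple; a hedged Python sketch of what it amounts to (hypothetical helper, assuming -1 marks a dynamic dimension in the shape vector):

def smooth_l1_loss_grad_infer_shape(prediction_shape, target_shape, dloss_shape):
    # all three inputs must share one shape; the output follows prediction,
    # so any -1 (dynamic) dimensions are carried through unchanged
    if not (prediction_shape == target_shape == dloss_shape):
        raise TypeError("prediction, target and dloss must have the same shape")
    return prediction_shape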


@@ -38,6 +38,7 @@ class MS_CORE_API SmoothL1LossGrad : public PrimitiveC {
};
AbstractBasePtr SmoothL1LossGradInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using kPrimSmoothL1LossGradPtr = std::shared_ptr<SmoothL1LossGrad>;
} // namespace ops
} // namespace mindspore
#endif // MINDSPORE_CORE_OPS_SMOOTH_L1_LOSS_GRAD_H_


@@ -33,27 +33,40 @@ float SmoothL1Loss::get_beta() const {
return GetValue<int32_t>(value_ptr);
}
AbstractBasePtr SmoothL1LossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
namespace {
abstract::ShapePtr SmoothL1LossInferShape(const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
MS_EXCEPTION_IF_NULL(primitive);
auto prim_name = primitive->name();
const int64_t input_num = 2;
CheckAndConvertUtils::CheckInputArgs(input_args, kEqual, input_num, prim_name);
auto prediction = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex0);
auto target = CheckAndConvertUtils::CheckArgs<abstract::AbstractTensor>(prim_name, input_args, kInputIndex1);
abstract::CheckShapeSame(prim_name, prediction, target);
auto x = input_args[kInputIndex0]->BuildShape();
MS_EXCEPTION_IF_NULL(x);
auto shape_element = x->cast<abstract::ShapePtr>();
MS_EXCEPTION_IF_NULL(shape_element);
return shape_element;
}
// Infer shape
auto prediction = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
auto target = CheckAndConvertUtils::ConvertShapePtrToShapeMap(input_args[0]->BuildShape())[kShape];
CheckAndConvertUtils::Check("prediction shape", prediction, kEqual, "target shape", target, prim_name, TypeError);
TypePtr SmoothL1LossInferType(const PrimitivePtr &prim, const std::vector<AbstractBasePtr> &input_args) {
// Infer type
const std::set<TypePtr> valid_types = {kFloat16, kFloat32};
std::map<std::string, TypePtr> args;
args.emplace("scale", input_args[0]->BuildType());
args.emplace("bias", input_args[1]->BuildType());
auto prediction_type = CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim_name);
return std::make_shared<abstract::AbstractTensor>(prediction_type, prediction);
(void)args.emplace("scale", input_args[kInputIndex0]->BuildType());
(void)args.emplace("bias", input_args[kInputIndex1]->BuildType());
auto prediction_type = CheckAndConvertUtils::CheckTensorTypeSame(args, valid_types, prim->name());
return prediction_type;
}
REGISTER_PRIMITIVE_C(kNameSmoothL1Loss, SmoothL1Loss);
} // namespace
AbstractBasePtr SmoothL1LossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args) {
auto infer_type = SmoothL1LossInferType(primitive, input_args);
auto infer_shape = SmoothL1LossInferShape(primitive, input_args);
return abstract::MakeAbstract(infer_shape, infer_type);
}
REGISTER_PRIMITIVE_EVAL_IMPL(SmoothL1Loss, prim::kPrimSmoothL1Loss, SmoothL1LossInfer, nullptr, true);
} // namespace ops
} // namespace mindspore


@@ -45,6 +45,7 @@ class MS_CORE_API SmoothL1Loss : public PrimitiveC {
};
AbstractBasePtr SmoothL1LossInfer(const abstract::AnalysisEnginePtr &, const PrimitivePtr &primitive,
const std::vector<AbstractBasePtr> &input_args);
using kPrimSmoothL1LossPtr = std::shared_ptr<SmoothL1Loss>;
} // namespace ops
} // namespace mindspore


@@ -282,7 +282,9 @@ from .pad_d_ds import _pad_d_ds_tbe
from .arg_max_with_value import _arg_max_with_value_tbe
from .arg_min_with_value import _arg_min_with_value_tbe
from .smooth_l1_loss import _smooth_l1_loss_tbe
from .smooth_l1_loss_ds import _smooth_l1_loss_ds_tbe
from .smooth_l1_loss_grad import _smooth_l1_loss_grad_tbe
from .smooth_l1_loss_grad_ds import _smooth_l1_loss_grad_ds_tbe
from .soft_margin_loss import _soft_margin_loss_tbe
from .soft_margin_loss_grad import _soft_margin_loss_grad_tbe
from .fused_mul_add import _fused_mul_add_tbe


@@ -0,0 +1,45 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""SmoothL1Loss op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
smooth_l1_loss_op_info = TBERegOp("SmoothL1Loss") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("smooth_l1_loss.so") \
.compute_cost(10) \
.kernel_name("smooth_l1_loss") \
.partial_flag(True) \
.dynamic_shape(True) \
.attr("beta", "required", "float", "all") \
.input(0, "predict", False, "required", "all") \
.input(1, "label", False, "required", "all") \
.output(0, "loss", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ) \
.dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
.get_op_info()
@op_info_register(smooth_l1_loss_op_info)
def _smooth_l1_loss_ds_tbe():
"""SmoothL1Loss TBE register"""
return


@@ -0,0 +1,46 @@
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""SmoothL1LossGrad op"""
from mindspore.ops.op_info_register import op_info_register, TBERegOp, DataType
smooth_l1_loss_grad_op_info = TBERegOp("SmoothL1LossGrad") \
.fusion_type("OPAQUE") \
.async_flag(False) \
.binfile_name("smooth_l1_loss_grad.so") \
.compute_cost(10) \
.kernel_name("smooth_l1_loss_grad") \
.partial_flag(True) \
.dynamic_shape(True) \
.attr("beta", "required", "float", "all") \
.input(0, "predict", False, "required", "all") \
.input(1, "label", False, "required", "all") \
.input(2, "dout", False, "required", "all") \
.output(0, "loss", False, "required", "all") \
.dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
.dtype_format(DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD, DataType.F16_5HD) \
.dtype_format(DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ, DataType.F16_FracZ) \
.dtype_format(DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0, DataType.F16_C1HWNCoC0) \
.dtype_format(DataType.F32_Default, DataType.F32_Default, DataType.F32_Default, DataType.F32_Default) \
.dtype_format(DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD, DataType.F32_5HD) \
.dtype_format(DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ, DataType.F32_FracZ) \
.dtype_format(DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0, DataType.F32_C1HWNCoC0) \
.get_op_info()
@op_info_register(smooth_l1_loss_grad_op_info)
def _smooth_l1_loss_grad_ds_tbe():
"""SmoothL1LossGrad TBE register"""
return


@@ -1822,23 +1822,13 @@ class NLLLossGrad(PrimitiveWithInfer):
return x_dtype
class SmoothL1LossGrad(PrimitiveWithInfer):
class SmoothL1LossGrad(Primitive):
"""Computes gradient for prediction on SmoothL1Loss."""
@prim_attr_register
def __init__(self, beta=1.0):
pass
def infer_shape(self, prediction, target, dloss):
validator.check('prediction shape', prediction, 'target shape', target, Rel.EQ, self.name)
validator.check('prediction shape', prediction, 'dloss shape', dloss, Rel.EQ, self.name)
return prediction
def infer_dtype(self, prediction, target, dloss):
args = {"prediction": prediction, "target": target, 'dloss': dloss}
validator.check_tensors_dtypes_same_and_valid(args, mstype.number_type, self.name)
return dloss
class SoftMarginLossGrad(Primitive):
"""Computes gradient for prediction on SoftMarginLoss."""

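With the Python-side infer_shape and infer_dtype removed above, SmoothL1LossGrad now relies entirely on the C++ infer functions registered earlier in this change. Its three-input contract is unchanged; a minimal usage sketch (assuming the internal mindspore.ops.operations._grad_ops module path, and noting the primitive is normally only called from the bprop of SmoothL1Loss):

import numpy as np
from mindspore import Tensor
from mindspore.ops.operations import _grad_ops as G

grad_op = G.SmoothL1LossGrad(beta=1.0)
prediction = Tensor(np.array([1.0, 2.0, 3.0], np.float32))
target = Tensor(np.array([1.0, 2.5, 2.0], np.float32))
dloss = Tensor(np.ones(3).astype(np.float32))  # incoming gradient from the loss
dprediction = grad_op(prediction, target, dloss)  # same shape as prediction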

@@ -2604,7 +2604,7 @@ class ApplyMomentum(Primitive):
self.add_prim_attr('side_effect_mem', True)
class SmoothL1Loss(PrimitiveWithInfer):
class SmoothL1Loss(Primitive):
r"""
Computes smooth L1 loss, a robust L1 loss.
@@ -2667,15 +2667,6 @@ class SmoothL1Loss(PrimitiveWithInfer):
validator.check('beta', beta, '', 0, Rel.GT, self.name)
self.init_prim_io_names(inputs=['prediction', 'target'], outputs=['output'])
def infer_shape(self, prediction, target):
validator.check('prediction shape', prediction, 'target shape', target, Rel.EQ, self.name)
return prediction
def infer_dtype(self, prediction, target):
args = {"prediction": prediction, "target": target}
validator.check_tensors_dtypes_same_and_valid(args, (mstype.float16, mstype.float32), self.name)
return prediction
class SoftMarginLoss(Primitive):
r"""